1#   Copyright 2020 The PyMC Developers
2#
3#   Licensed under the Apache License, Version 2.0 (the "License");
4#   you may not use this file except in compliance with the License.
5#   You may obtain a copy of the License at
6#
7#       http://www.apache.org/licenses/LICENSE-2.0
8#
9#   Unless required by applicable law or agreed to in writing, software
10#   distributed under the License is distributed on an "AS IS" BASIS,
11#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12#   See the License for the specific language governing permissions and
13#   limitations under the License.
14
15"""Functions for MCMC sampling."""
16
17import collections.abc as abc
18import logging
19import pickle
20import sys
21import time
22import warnings
23
24from collections import defaultdict
25from copy import copy, deepcopy
26from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast
27
28import arviz
29import numpy as np
30import packaging
31import theano.gradient as tg
32import xarray
33
34from arviz import InferenceData
35from fastprogress.fastprogress import progress_bar
36
37import pymc3 as pm
38
39from pymc3.backends.base import BaseTrace, MultiTrace
40from pymc3.backends.ndarray import NDArray
41from pymc3.distributions.distribution import draw_values
42from pymc3.distributions.posterior_predictive import fast_sample_posterior_predictive
43from pymc3.exceptions import IncorrectArgumentsError, SamplingError
44from pymc3.model import Model, Point, all_continuous, modelcontext
45from pymc3.parallel_sampling import Draw, _cpu_count
46from pymc3.step_methods import (
47    NUTS,
48    PGBART,
49    BinaryGibbsMetropolis,
50    BinaryMetropolis,
51    CategoricalGibbsMetropolis,
52    CompoundStep,
53    DEMetropolis,
54    HamiltonianMC,
55    Metropolis,
56    Slice,
57)
58from pymc3.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
59from pymc3.step_methods.hmc import quadpotential
60from pymc3.util import (
61    chains_and_samples,
62    check_start_vals,
63    dataset_to_point_list,
64    get_default_varnames,
65    get_untransformed_name,
66    is_transformed_name,
67    update_start_vals,
68)
69from pymc3.vartypes import discrete_types
70
71sys.setrecursionlimit(10000)
72
73__all__ = [
74    "sample",
75    "iter_sample",
76    "sample_posterior_predictive",
77    "sample_posterior_predictive_w",
78    "init_nuts",
79    "sample_prior_predictive",
80    "fast_sample_posterior_predictive",
81]
82
83STEP_METHODS = (
84    NUTS,
85    HamiltonianMC,
86    Metropolis,
87    BinaryMetropolis,
88    BinaryGibbsMetropolis,
89    Slice,
90    CategoricalGibbsMetropolis,
91    PGBART,
92)
93Step = Union[BlockedStep, CompoundStep]
94
95ArrayLike = Union[np.ndarray, List[float]]
96PointType = Dict[str, np.ndarray]
97PointList = List[PointType]
98Backend = Union[BaseTrace, MultiTrace, NDArray]
99
100_log = logging.getLogger("pymc3")
101
102
103def instantiate_steppers(
104    _model, steps: List[Step], selected_steps, step_kwargs=None
105) -> Union[Step, List[Step]]:
106    """Instantiate steppers assigned to the model variables.
107
108    This function is intended to be called automatically from ``sample()``, but
109    may be called manually.
110
111    Parameters
112    ----------
113    model : Model object
114        A fully-specified model object; legacy argument -- ignored
115    steps : list
116        A list of zero or more step function instances that have been assigned to some subset of
117        the model's parameters.
118    selected_steps : dict
119        A dictionary that maps a step method class to a list of zero or more model variables.
120    step_kwargs : dict
        Parameters for the samplers. Keys are the lower-case names of
        the step methods; values are dicts of keyword arguments. Defaults to None.
123
124    Returns
125    -------
126    methods : list or step
127        List of step methods associated with the model's variables, or step method
128        if there is only one.
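
    Examples
    --------
    A minimal sketch; the model below is hypothetical, and the keyword arguments a
    step method accepts depend on the step class::

        import pymc3 as pm

        with pm.Model() as model:
            x = pm.Normal("x", 0, 1)
            # request a Metropolis sampler for ``x`` with a custom proposal scaling
            step = instantiate_steppers(
                model,
                steps=[],
                selected_steps={pm.Metropolis: [x]},
                step_kwargs={"metropolis": {"scaling": 0.5}},
            )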
129    """
130    if step_kwargs is None:
131        step_kwargs = {}
132
133    used_keys = set()
134    for step_class, vars in selected_steps.items():
135        if vars:
136            args = step_kwargs.get(step_class.name, {})
137            used_keys.add(step_class.name)
138            step = step_class(vars=vars, **args)
139            steps.append(step)
140
141    unused_args = set(step_kwargs).difference(used_keys)
142    if unused_args:
143        raise ValueError("Unused step method arguments: %s" % unused_args)
144
145    if len(steps) == 1:
146        return steps[0]
147
148    return steps
149
150
151def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None):
152    """Assign model variables to appropriate step methods.
153
154    Passing a specified model will auto-assign its constituent stochastic
155    variables to step methods based on the characteristics of the variables.
156    This function is intended to be called automatically from ``sample()``, but
157    may be called manually. Each step method passed should have a
158    ``competence()`` method that returns an ordinal competence value
159    corresponding to the variable passed to it. This value quantifies the
160    appropriateness of the step method for sampling the variable.
161
162    Parameters
163    ----------
164    model : Model object
165        A fully-specified model object
166    step : step function or vector of step functions
167        One or more step functions that have been assigned to some subset of
168        the model's parameters. Defaults to ``None`` (no assigned variables).
169    methods : vector of step method classes
170        The set of step methods from which the function may choose. Defaults
171        to the main step methods provided by PyMC3.
172    step_kwargs : dict
        Parameters for the samplers. Keys are the lower-case names of
        the step methods; values are dicts of keyword arguments.
175
176    Returns
177    -------
178    methods : list
179        List of step methods associated with the model's variables.
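
    Examples
    --------
    A minimal sketch; the model is hypothetical, and the assignment noted in the
    comments is the typical outcome of the competence ranking rather than a
    guarantee::

        import pymc3 as pm

        with pm.Model() as model:
            mu = pm.Normal("mu", 0, 1)   # continuous, gradient available
            k = pm.Bernoulli("k", 0.5)   # discrete, no gradient
            steps = assign_step_methods(model)
            # typically a list with a NUTS step (for ``mu``) and a
            # BinaryGibbsMetropolis step (for ``k``); ``sample()`` wraps such a
            # list in a CompoundStep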
180    """
181    steps = []
182    assigned_vars = set()
183
184    if step is not None:
185        try:
186            steps += list(step)
187        except TypeError:
188            steps.append(step)
189        for step in steps:
190            try:
191                assigned_vars = assigned_vars.union(set(step.vars))
192            except AttributeError:
193                for method in step.methods:
194                    assigned_vars = assigned_vars.union(set(method.vars))
195
196    # Use competence classmethods to select step methods for remaining
197    # variables
198    selected_steps = defaultdict(list)
199    for var in model.free_RVs:
200        if var not in assigned_vars:
201            # determine if a gradient can be computed
202            has_gradient = var.dtype not in discrete_types
203            if has_gradient:
204                try:
205                    tg.grad(model.logpt, var)
206                except (AttributeError, NotImplementedError, tg.NullTypeGradError):
207                    has_gradient = False
208            # select the best method
209            selected = max(
210                methods,
211                key=lambda method, var=var, has_gradient=has_gradient: method._competence(
212                    var, has_gradient
213                ),
214            )
215            selected_steps[selected].append(var)
216
217    return instantiate_steppers(model, steps, selected_steps, step_kwargs)
218
219
220def _print_step_hierarchy(s: Step, level=0) -> None:
221    if isinstance(s, CompoundStep):
222        _log.info(">" * level + "CompoundStep")
223        for i in s.methods:
224            _print_step_hierarchy(i, level + 1)
225    else:
226        varnames = ", ".join(
227            [
228                get_untransformed_name(v.name) if is_transformed_name(v.name) else v.name
229                for v in s.vars
230            ]
231        )
232        _log.info(">" * level + f"{s.__class__.__name__}: [{varnames}]")
233
234
235def sample(
236    draws=1000,
237    step=None,
238    init="auto",
239    n_init=200000,
240    start=None,
241    trace=None,
242    chain_idx=0,
243    chains=None,
244    cores=None,
245    tune=1000,
246    progressbar=True,
247    model=None,
248    random_seed=None,
249    discard_tuned_samples=True,
250    compute_convergence_checks=True,
251    callback=None,
252    jitter_max_retries=10,
253    *,
254    return_inferencedata=None,
255    idata_kwargs: dict = None,
256    mp_ctx=None,
257    pickle_backend: str = "pickle",
258    **kwargs,
259):
260    r"""Draw samples from the posterior using the given step methods.
261
262    Multiple step methods are supported via compound step methods.
263
264    Parameters
265    ----------
266    draws : int
        The number of samples to draw. Defaults to 1000. The tuning samples are discarded
        by default. See ``discard_tuned_samples``.
269    init : str
270        Initialization method to use for auto-assigned NUTS samplers.
271
272        * auto: Choose a default initialization method automatically.
273          Currently, this is ``jitter+adapt_diag``, but this can change in the future.
274          If you depend on the exact behaviour, choose an initialization method explicitly.
        * adapt_diag: Start with an identity mass matrix and then adapt a diagonal based on the
276          variance of the tuning samples. All chains use the test value (usually the prior mean)
277          as starting point.
278        * jitter+adapt_diag: Same as ``adapt_diag``, but add uniform jitter in [-1, 1] to the
279          starting point in each chain.
280        * advi+adapt_diag: Run ADVI and then adapt the resulting diagonal mass matrix based on the
281          sample variance of the tuning samples.
282        * advi+adapt_diag_grad: Run ADVI and then adapt the resulting diagonal mass matrix based
283          on the variance of the gradients during tuning. This is **experimental** and might be
284          removed in a future release.
285        * advi: Run ADVI to estimate posterior mean and diagonal mass matrix.
286        * advi_map: Initialize ADVI with MAP and use MAP as starting point.
287        * map: Use the MAP as starting point. This is discouraged.
288        * adapt_full: Adapt a dense mass matrix using the sample covariances
289
290    step : function or iterable of functions
291        A step function or collection of functions. If there are variables without step methods,
292        step methods for those variables will be assigned automatically.  By default the NUTS step
293        method will be used, if appropriate to the model; this is a good default for beginning
294        users.
295    n_init : int
        Number of iterations of the initializer. Only works for ADVI-based init methods.
297    start : dict, or array of dict
        Starting point in parameter space (or partial point).
        Defaults to ``trace.point(-1)`` if there is a trace provided and ``model.test_point`` if not
        (defaults to an empty dict). Initialization methods for NUTS (see the ``init`` keyword) can
        overwrite the default.
302    trace : backend, list, or MultiTrace
303        This should be a backend instance, a list of variables to track, or a MultiTrace object
304        with past values. If a MultiTrace object is given, it must contain samples for the chain
305        number ``chain``. If None or a list of variables, the NDArray backend is used.
306    chain_idx : int
307        Chain number used to store sample in backend. If ``chains`` is greater than one, chain
308        numbers will start here.
309    chains : int
310        The number of chains to sample. Running independent chains is important for some
311        convergence statistics and can also reveal multiple modes in the posterior. If ``None``,
312        then set to either ``cores`` or 2, whichever is larger.
313    cores : int
314        The number of chains to run in parallel. If ``None``, set to the number of CPUs in the
315        system, but at most 4.
316    tune : int
317        Number of iterations to tune, defaults to 1000. Samplers adjust the step sizes, scalings or
318        similar during tuning. Tuning samples will be drawn in addition to the number specified in
319        the ``draws`` argument, and will be discarded unless ``discard_tuned_samples`` is set to
320        False.
321    progressbar : bool, optional default=True
322        Whether or not to display a progress bar in the command line. The bar shows the percentage
323        of completion, the sampling speed in samples per second (SPS), and the estimated remaining
324        time until completion ("expected time of arrival"; ETA).
325    model : Model (optional if in ``with`` context)
326    random_seed : int or list of ints
327        A list is accepted if ``cores`` is greater than one.
328    discard_tuned_samples : bool
329        Whether to discard posterior samples of the tune interval.
330    compute_convergence_checks : bool, default=True
331        Whether to compute sampler statistics like Gelman-Rubin and ``effective_n``.
    callback : function, default=None
        A function which gets called for every sample from the trace of a chain. The function is
        called with the trace and the current draw; the trace will contain all samples for a single chain.
        The ``draw.chain`` attribute can be used to determine which of the active chains the sample
        is drawn from.
        Sampling can be interrupted by raising a ``KeyboardInterrupt`` in the callback.
338    jitter_max_retries : int
        Maximum number of repeated attempts (per chain) at creating an initial point with uniform jitter
340        that yields a finite probability. This applies to ``jitter+adapt_diag`` and ``jitter+adapt_full``
341        init methods.
342    return_inferencedata : bool, default=False
        Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False).
344        Defaults to `False`, but we'll switch to `True` in an upcoming release.
345    idata_kwargs : dict, optional
346        Keyword arguments for :func:`arviz:arviz.from_pymc3`
    mp_ctx : multiprocessing.context.BaseContext
348        A multiprocessing context for parallel sampling. See multiprocessing
349        documentation for details.
350    pickle_backend : str
351        One of `'pickle'` or `'dill'`. The library used to pickle models
352        in parallel sampling if the multiprocessing context is not of type
353        `fork`.
354
355    Returns
356    -------
357    trace : pymc3.backends.base.MultiTrace or arviz.InferenceData
358        A ``MultiTrace`` or ArviZ ``InferenceData`` object that contains the samples.
359
360    Notes
361    -----
362    Optional keyword arguments can be passed to ``sample`` to be delivered to the
363    ``step_method``\ s used during sampling.
364
365    If your model uses only one step method, you can address step method kwargs
366    directly. In particular, the NUTS step method has several options including:
367
368        * target_accept : float in [0, 1]. The step size is tuned such that we
369          approximate this acceptance rate. Higher values like 0.9 or 0.95 often
370          work better for problematic posteriors
371        * max_treedepth : The maximum depth of the trajectory tree
372        * step_scale : float, default 0.25
          The initial guess for the step size scaled down by :math:`1/n^{1/4}`
374
375    If your model uses multiple step methods, aka a Compound Step, then you have
376    two ways to address arguments to each step method:
377
378    A. If you let ``sample()`` automatically assign the ``step_method``\ s,
379       and you can correctly anticipate what they will be, then you can wrap
380       step method kwargs in a dict and pass that to sample() with a kwarg set
381       to the name of the step method.
382       e.g. for a CompoundStep comprising NUTS and BinaryGibbsMetropolis,
383       you could send:
384
385       1. ``target_accept`` to NUTS: nuts={'target_accept':0.9}
386       2. ``transit_p`` to BinaryGibbsMetropolis: binary_gibbs_metropolis={'transit_p':.7}
387
388       Note that available names are:
389
390        ``nuts``, ``hmc``, ``metropolis``, ``binary_metropolis``,
391        ``binary_gibbs_metropolis``, ``categorical_gibbs_metropolis``,
392        ``DEMetropolis``, ``DEMetropolisZ``, ``slice``
393
394    B. If you manually declare the ``step_method``\ s, within the ``step``
395       kwarg, then you can address the ``step_method`` kwargs directly.
396       e.g. for a CompoundStep comprising NUTS and BinaryGibbsMetropolis,
397       you could send ::
398
399        step=[pm.NUTS([freeRV1, freeRV2], target_accept=0.9),
400              pm.BinaryGibbsMetropolis([freeRV3], transit_p=.7)]
401
402    You can find a full list of arguments in the docstring of the step methods.
403
404    Examples
405    --------
406    .. code:: ipython
407
        In [1]: import arviz as az
           ...: import pymc3 as pm
409           ...: n = 100
410           ...: h = 61
411           ...: alpha = 2
412           ...: beta = 2
413
414        In [2]: with pm.Model() as model: # context management
415           ...:     p = pm.Beta("p", alpha=alpha, beta=beta)
416           ...:     y = pm.Binomial("y", n=n, p=p, observed=h)
417           ...:     trace = pm.sample()
418
419        In [3]: az.summary(trace, kind="stats")
420
421        Out[3]:
422            mean     sd  hdi_3%  hdi_97%
423        p  0.609  0.047   0.528    0.699
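
    If the model ends up with a single step method, its keyword arguments can be
    passed to ``sample()`` directly; the value below is only illustrative:

    .. code:: ipython

        In [4]: with model:
           ...:     trace = pm.sample(tune=2000, target_accept=0.95)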
424    """
425    model = modelcontext(model)
426    start = deepcopy(start)
427    if start is None:
428        check_start_vals(model.test_point, model)
429    else:
430        if isinstance(start, dict):
431            update_start_vals(start, model.test_point, model)
432        else:
433            for chain_start_vals in start:
434                update_start_vals(chain_start_vals, model.test_point, model)
435        check_start_vals(start, model)
436
437    if cores is None:
438        cores = min(4, _cpu_count())
439
440    if chains is None:
441        chains = max(2, cores)
442    if isinstance(start, dict):
443        start = [start] * chains
444    if random_seed == -1:
445        random_seed = None
446    if chains == 1 and isinstance(random_seed, int):
447        random_seed = [random_seed]
448    if random_seed is None or isinstance(random_seed, int):
449        if random_seed is not None:
450            np.random.seed(random_seed)
451        random_seed = [np.random.randint(2 ** 30) for _ in range(chains)]
452    if not isinstance(random_seed, abc.Iterable):
453        raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")
454
455    if not discard_tuned_samples and not return_inferencedata:
456        warnings.warn(
457            "Tuning samples will be included in the returned `MultiTrace` object, which can lead to"
            " complications in your downstream analysis. Please consider switching to `InferenceData`:\n"
459            "`pm.sample(..., return_inferencedata=True)`",
460            UserWarning,
461        )
462
463    if return_inferencedata is None:
464        v = packaging.version.parse(pm.__version__)
465        if v.release[0] > 3 or v.release[1] >= 10:  # type: ignore
466            warnings.warn(
467                "In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. "
468                "You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.",
469                FutureWarning,
470                stacklevel=2,
471            )
472        # set the default
473        return_inferencedata = False
474
475    if start is not None:
476        for start_vals in start:
477            _check_start_shape(model, start_vals)
478
479    # small trace warning
480    if draws == 0:
481        msg = "Tuning was enabled throughout the whole trace."
482        _log.warning(msg)
483    elif draws < 500:
484        msg = "Only %s samples in chain." % draws
485        _log.warning(msg)
486
487    draws += tune
488
489    if model.ndim == 0:
490        raise ValueError("The model does not contain any free variables.")
491
492    if step is None and init is not None and all_continuous(model.vars):
493        try:
494            # By default, try to use NUTS
495            _log.info("Auto-assigning NUTS sampler...")
496            start_, step = init_nuts(
497                init=init,
498                chains=chains,
499                n_init=n_init,
500                model=model,
501                random_seed=random_seed,
502                progressbar=progressbar,
503                jitter_max_retries=jitter_max_retries,
504                **kwargs,
505            )
506            if start is None:
507                start = start_
508                check_start_vals(start, model)
509        except (AttributeError, NotImplementedError, tg.NullTypeGradError):
510            # gradient computation failed
            _log.info("Initializing NUTS failed. Falling back to elementwise auto-assignment.")
            _log.debug("Exception in init nuts", exc_info=True)
513            step = assign_step_methods(model, step, step_kwargs=kwargs)
514    else:
515        step = assign_step_methods(model, step, step_kwargs=kwargs)
516
517    if isinstance(step, list):
518        step = CompoundStep(step)
519    if start is None:
520        start = {}
521    if isinstance(start, dict):
522        start = [start] * chains
523
524    sample_args = {
525        "draws": draws,
526        "step": step,
527        "start": start,
528        "trace": trace,
529        "chain": chain_idx,
530        "chains": chains,
531        "tune": tune,
532        "progressbar": progressbar,
533        "model": model,
534        "random_seed": random_seed,
535        "cores": cores,
536        "callback": callback,
537        "discard_tuned_samples": discard_tuned_samples,
538    }
539    parallel_args = {
540        "pickle_backend": pickle_backend,
541        "mp_ctx": mp_ctx,
542    }
543
544    sample_args.update(kwargs)
545
546    has_population_samplers = np.any(
547        [
548            isinstance(m, PopulationArrayStepShared)
549            for m in (step.methods if isinstance(step, CompoundStep) else [step])
550        ]
551    )
552
553    parallel = cores > 1 and chains > 1 and not has_population_samplers
554    t_start = time.time()
555    if parallel:
556        _log.info(f"Multiprocess sampling ({chains} chains in {cores} jobs)")
557        _print_step_hierarchy(step)
558        try:
559            trace = _mp_sample(**sample_args, **parallel_args)
560        except pickle.PickleError:
561            _log.warning("Could not pickle model, sampling singlethreaded.")
            _log.debug("Pickling error:", exc_info=True)
563            parallel = False
564        except AttributeError as e:
565            if not str(e).startswith("AttributeError: Can't pickle"):
566                raise
567            _log.warning("Could not pickle model, sampling singlethreaded.")
            _log.debug("Pickling error:", exc_info=True)
569            parallel = False
570    if not parallel:
571        if has_population_samplers:
572            has_demcmc = np.any(
573                [
574                    isinstance(m, DEMetropolis)
575                    for m in (step.methods if isinstance(step, CompoundStep) else [step])
576                ]
577            )
578            _log.info(f"Population sampling ({chains} chains)")
579            if has_demcmc and chains < 3:
580                raise ValueError(
581                    "DEMetropolis requires at least 3 chains. "
582                    "For this {}-dimensional model you should use ≥{} chains".format(
583                        model.ndim, model.ndim + 1
584                    )
585                )
586            if has_demcmc and chains <= model.ndim:
587                warnings.warn(
588                    "DEMetropolis should be used with more chains than dimensions! "
589                    "(The model has {} dimensions.)".format(model.ndim),
590                    UserWarning,
591                )
592            _print_step_hierarchy(step)
593            trace = _sample_population(parallelize=cores > 1, **sample_args)
594        else:
595            _log.info(f"Sequential sampling ({chains} chains in 1 job)")
596            _print_step_hierarchy(step)
597            trace = _sample_many(**sample_args)
598
599    t_sampling = time.time() - t_start
600    # count the number of tune/draw iterations that happened
601    # ideally via the "tune" statistic, but not all samplers record it!
602    if "tune" in trace.stat_names:
603        stat = trace.get_sampler_stats("tune", chains=0)
604        # when CompoundStep is used, the stat is 2 dimensional!
605        if len(stat.shape) == 2:
606            stat = stat[:, 0]
607        stat = tuple(stat)
608        n_tune = stat.count(True)
609        n_draws = stat.count(False)
610    else:
611        # these may be wrong when KeyboardInterrupt happened, but they're better than nothing
612        n_tune = min(tune, len(trace))
613        n_draws = max(0, len(trace) - n_tune)
614
615    if discard_tuned_samples:
616        trace = trace[n_tune:]
617
618    # save metadata in SamplerReport
619    trace.report._n_tune = n_tune
620    trace.report._n_draws = n_draws
621    trace.report._t_sampling = t_sampling
622
623    if "variable_inclusion" in trace.stat_names:
624        variable_inclusion = np.stack(trace.get_sampler_stats("variable_inclusion")).mean(0)
625        trace.report.variable_importance = variable_inclusion / variable_inclusion.sum()
626
627    n_chains = len(trace.chains)
628    _log.info(
629        f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {n_tune:_d} tune and {n_draws:_d} draw iterations '
630        f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) "
631        f"took {trace.report.t_sampling:.0f} seconds."
632    )
633
634    idata = None
635    if compute_convergence_checks or return_inferencedata:
636        ikwargs = dict(model=model, save_warmup=not discard_tuned_samples)
637        if idata_kwargs:
638            ikwargs.update(idata_kwargs)
639        idata = arviz.from_pymc3(trace, **ikwargs)
640
641    if compute_convergence_checks:
642        if draws - tune < 100:
643            warnings.warn("The number of samples is too small to check convergence reliably.")
644        else:
645            trace.report._run_convergence_checks(idata, model)
646    trace.report._log_summary()
647
648    if return_inferencedata:
649        return idata
650    else:
651        return trace
652
653
654def _check_start_shape(model, start):
655    if not isinstance(start, dict):
656        raise TypeError("start argument must be a dict or an array-like of dicts")
657    e = ""
658    for var in model.vars:
659        if var.name in start.keys():
660            var_shape = var.shape.tag.test_value
661            start_var_shape = np.shape(start[var.name])
662            if start_var_shape:
663                if not np.array_equal(var_shape, start_var_shape):
664                    e += "\nExpected shape {} for var '{}', got: {}".format(
665                        tuple(var_shape), var.name, start_var_shape
666                    )
667            # if start var has no shape
668            else:
669                # if model var has a specified shape
670                if var_shape.size > 0:
                    e += "\nExpected shape {} for var '{}', got scalar {}".format(
672                        tuple(var_shape), var.name, start[var.name]
673                    )
674
675    if e != "":
676        raise ValueError(f"Bad shape for start argument:{e}")
677
678
679def _sample_many(
680    draws,
681    chain: int,
682    chains: int,
683    start: list,
684    random_seed: list,
685    step,
686    callback=None,
687    **kwargs,
688):
689    """Samples all chains sequentially.
690
691    Parameters
692    ----------
693    draws: int
694        The number of samples to draw
695    chain: int
696        Number of the first chain in the sequence.
697    chains: int
698        Total number of chains to sample.
699    start: list
700        Starting points for each chain
701    random_seed: list
702        A list of seeds, one for each chain
703    step: function
704        Step function
705
706    Returns
707    -------
708    trace: MultiTrace
709        Contains samples of all chains
710    """
711    traces: List[Backend] = []
712    for i in range(chains):
713        trace = _sample(
714            draws=draws,
715            chain=chain + i,
716            start=start[i],
717            step=step,
718            random_seed=random_seed[i],
719            callback=callback,
720            **kwargs,
721        )
722        if trace is None:
723            if len(traces) == 0:
724                raise ValueError("Sampling stopped before a sample was created.")
725            else:
726                break
727        elif len(trace) < draws:
728            if len(traces) == 0:
729                traces.append(trace)
730            break
731        else:
732            traces.append(trace)
733    return MultiTrace(traces)
734
735
736def _sample_population(
737    draws: int,
738    chain: int,
739    chains: int,
740    start,
741    random_seed,
742    step,
743    tune,
744    model,
745    progressbar: bool = True,
746    parallelize=False,
747    **kwargs,
748):
749    """Performs sampling of a population of chains using the ``PopulationStepper``.
750
751    Parameters
752    ----------
753    draws : int
754        The number of samples to draw
755    chain : int
756        The number of the first chain in the population
757    chains : int
758        The total number of chains in the population
759    start : list
760        Start points for each chain
761    random_seed : int or list of ints, optional
        A list is accepted if ``cores`` is greater than one.
763    step : function
764        Step function (should be or contain a population step method)
765    tune : int, optional
766        Number of iterations to tune, if applicable (defaults to None)
767    model : Model (optional if in ``with`` context)
768    progressbar : bool
769        Show progress bars? (defaults to True)
770    parallelize : bool
771        Setting for multiprocess parallelization
772
773    Returns
774    -------
775    trace : MultiTrace
776        Contains samples of all chains
777    """
778    sampling = _prepare_iter_population(
779        draws,
780        [chain + c for c in range(chains)],
781        step,
782        start,
783        parallelize,
784        tune=tune,
785        model=model,
786        random_seed=random_seed,
787        progressbar=progressbar,
788    )
789
790    if progressbar:
791        sampling = progress_bar(sampling, total=draws, display=progressbar)
792
793    latest_traces = None
794    for it, traces in enumerate(sampling):
795        latest_traces = traces
796    return MultiTrace(latest_traces)
797
798
799def _sample(
800    chain: int,
801    progressbar: bool,
802    random_seed,
803    start,
804    draws: int,
805    step=None,
806    trace=None,
807    tune=None,
808    model: Optional[Model] = None,
809    callback=None,
810    **kwargs,
811):
812    """Main iteration for singleprocess sampling.
813
814    Multiple step methods are supported via compound step methods.
815
816    Parameters
817    ----------
818    chain : int
819        Number of the chain that the samples will belong to.
820    progressbar : bool
821        Whether or not to display a progress bar in the command line. The bar shows the percentage
822        of completion, the sampling speed in samples per second (SPS), and the estimated remaining
823        time until completion ("expected time of arrival"; ETA).
824    random_seed : int or list of ints
825        A list is accepted if ``cores`` is greater than one.
826    start : dict
827        Starting point in parameter space (or partial point)
828    draws : int
829        The number of samples to draw
830    step : function
831        Step function
832    trace : backend, list, or MultiTrace
833        This should be a backend instance, a list of variables to track, or a MultiTrace object
834        with past values. If a MultiTrace object is given, it must contain samples for the chain
835        number ``chain``. If None or a list of variables, the NDArray backend is used.
836    tune : int, optional
837        Number of iterations to tune, if applicable (defaults to None)
838    model : Model (optional if in ``with`` context)
839
840    Returns
841    -------
842    strace : pymc3.backends.base.BaseTrace
843        A ``BaseTrace`` object that contains the samples for this chain.
844    """
845    skip_first = kwargs.get("skip_first", 0)
846
847    sampling = _iter_sample(draws, step, start, trace, chain, tune, model, random_seed, callback)
848    _pbar_data = {"chain": chain, "divergences": 0}
849    _desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
850    if progressbar:
851        sampling = progress_bar(sampling, total=draws, display=progressbar)
852        sampling.comment = _desc.format(**_pbar_data)
853    try:
854        strace = None
855        for it, (strace, diverging) in enumerate(sampling):
856            if it >= skip_first and diverging:
857                _pbar_data["divergences"] += 1
858                if progressbar:
859                    sampling.comment = _desc.format(**_pbar_data)
860    except KeyboardInterrupt:
861        pass
862    return strace
863
864
865def iter_sample(
866    draws: int,
867    step,
868    start: Optional[Dict[Any, Any]] = None,
869    trace=None,
870    chain=0,
871    tune: Optional[int] = None,
872    model: Optional[Model] = None,
873    random_seed: Optional[Union[int, List[int]]] = None,
874    callback=None,
875):
876    """Generate a trace on each iteration using the given step method.
877
    Multiple step methods are supported via compound step methods.
880
881    Parameters
882    ----------
883    draws : int
884        The number of samples to draw
885    step : function
886        Step function
887    start : dict
        Starting point in parameter space (or partial point). Defaults to ``trace.point(-1)`` if
        there is a trace provided and ``model.test_point`` if not (defaults to an empty dict).
890    trace : backend, list, or MultiTrace
891        This should be a backend instance, a list of variables to track, or a MultiTrace object
892        with past values. If a MultiTrace object is given, it must contain samples for the chain
893        number ``chain``. If None or a list of variables, the NDArray backend is used.
894    chain : int, optional
895        Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers
896        will start here.
897    tune : int, optional
898        Number of iterations to tune, if applicable (defaults to None)
899    model : Model (optional if in ``with`` context)
900    random_seed : int or list of ints, optional
        A list is accepted if ``cores`` is greater than one.
902    callback :
        A function which gets called for every sample from the trace of a chain. The function is
        called with the trace and the current draw; the trace will contain all samples for a single chain.
        The ``draw.chain`` attribute can be used to determine which of the active chains the sample
        is drawn from.
        Sampling can be interrupted by raising a ``KeyboardInterrupt`` in the callback.
908
909    Yields
910    ------
911    trace : MultiTrace
912        Contains all samples up to the current iteration
913
914    Examples
915    --------
916    ::
917
918        for trace in iter_sample(500, step):
919            ...
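
    For instance, to monitor the most recent value of a variable while sampling
    (assuming the model has a free variable named ``"p"``)::

        for trace in iter_sample(500, step):
            print(trace["p"][-1])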
920    """
921    sampling = _iter_sample(draws, step, start, trace, chain, tune, model, random_seed, callback)
922    for i, (strace, _) in enumerate(sampling):
923        yield MultiTrace([strace[: i + 1]])
924
925
926def _iter_sample(
927    draws,
928    step,
929    start=None,
930    trace=None,
931    chain=0,
932    tune=None,
933    model=None,
934    random_seed=None,
935    callback=None,
936):
937    """Generator for sampling one chain. (Used in singleprocess sampling.)
938
939    Parameters
940    ----------
941    draws : int
942        The number of samples to draw
943    step : function
944        Step function
945    start : dict, optional
        Starting point in parameter space (or partial point). Defaults to ``trace.point(-1)`` if
        there is a trace provided and ``model.test_point`` if not (defaults to an empty dict).
948    trace : backend, list, MultiTrace, or None
949        This should be a backend instance, a list of variables to track, or a MultiTrace object
950        with past values. If a MultiTrace object is given, it must contain samples for the chain
951        number ``chain``. If None or a list of variables, the NDArray backend is used.
952    chain : int, optional
953        Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers
954        will start here.
955    tune : int, optional
956        Number of iterations to tune, if applicable (defaults to None)
957    model : Model (optional if in ``with`` context)
958    random_seed : int or list of ints, optional
        A list is accepted if ``cores`` is greater than one.
960
961    Yields
962    ------
963    strace : BaseTrace
964        The trace object containing the samples for this chain
965    diverging : bool
966        Indicates if the draw is divergent. Only available with some samplers.
967    """
968    model = modelcontext(model)
969    draws = int(draws)
970    if random_seed is not None:
971        np.random.seed(random_seed)
972    if draws < 1:
973        raise ValueError("Argument `draws` must be greater than 0.")
974
975    if start is None:
976        start = {}
977
978    strace = _choose_backend(trace, chain, model=model)
979
980    if len(strace) > 0:
981        update_start_vals(start, strace.point(-1), model)
982    else:
983        update_start_vals(start, model.test_point, model)
984
985    try:
986        step = CompoundStep(step)
987    except TypeError:
988        pass
989
990    point = Point(start, model=model)
991
992    if step.generates_stats and strace.supports_sampler_stats:
993        strace.setup(draws, chain, step.stats_dtypes)
994    else:
995        strace.setup(draws, chain)
996
997    try:
998        step.tune = bool(tune)
999        if hasattr(step, "reset_tuning"):
1000            step.reset_tuning()
1001        for i in range(draws):
1002            stats = None
1003            diverging = False
1004
1005            if i == 0 and hasattr(step, "iter_count"):
1006                step.iter_count = 0
1007            if i == tune:
1008                step = stop_tuning(step)
1009            if step.generates_stats:
1010                point, stats = step.step(point)
1011                if strace.supports_sampler_stats:
1012                    strace.record(point, stats)
1013                    diverging = i > tune and stats and stats[0].get("diverging")
1014                else:
1015                    strace.record(point)
1016            else:
1017                point = step.step(point)
1018                strace.record(point)
1019            if callback is not None:
1020                warns = getattr(step, "warnings", None)
1021                callback(
1022                    trace=strace,
                    draw=Draw(chain, i == draws - 1, i, i < tune, stats, point, warns),
1024                )
1025
1026            yield strace, diverging
1027    except KeyboardInterrupt:
1028        strace.close()
1029        if hasattr(step, "warnings"):
1030            warns = step.warnings()
1031            strace._add_warnings(warns)
1032        raise
1033    except BaseException:
1034        strace.close()
1035        raise
1036    else:
1037        strace.close()
1038        if hasattr(step, "warnings"):
1039            warns = step.warnings()
1040            strace._add_warnings(warns)
1041
1042
1043class PopulationStepper:
    """Wraps a population of step methods to step them in parallel with single- or multiprocessing."""
1045
1046    def __init__(self, steppers, parallelize, progressbar=True):
1047        """Use multiprocessing to parallelize chains.
1048
1049        Falls back to sequential evaluation if multiprocessing fails.
1050
1051        In the multiprocessing mode of operation, a new process is started for each
1052        chain/stepper and Pipes are used to communicate with the main process.
1053
1054        Parameters
1055        ----------
1056        steppers : list
1057            A collection of independent step methods, one for each chain.
1058        parallelize : bool
1059            Indicates if parallelization via multiprocessing is desired.
1060        progressbar : bool
1061            Should we display a progress bar showing relative progress?
1062        """
1063        self.nchains = len(steppers)
1064        self.is_parallelized = False
1065        self._primary_ends = []
1066        self._processes = []
1067        self._steppers = steppers
1068        if parallelize:
1069            try:
1070                # configure a child process for each stepper
1071                _log.info(
1072                    "Attempting to parallelize chains to all cores. You can turn this off with `pm.sample(cores=1)`."
1073                )
1074                import multiprocessing
1075
1076                for c, stepper in (
1077                    enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers)
1078                ):
1079                    secondary_end, primary_end = multiprocessing.Pipe()
1080                    stepper_dumps = pickle.dumps(stepper, protocol=4)
1081                    process = multiprocessing.Process(
1082                        target=self.__class__._run_secondary,
1083                        args=(c, stepper_dumps, secondary_end),
1084                        name=f"ChainWalker{c}",
1085                    )
1086                    # we want the child process to exit if the parent is terminated
1087                    process.daemon = True
1088                    # Starting the process might fail and takes time.
1089                    # By doing it in the constructor, the sampling progress bar
1090                    # will not be confused by the process start.
1091                    process.start()
1092                    self._primary_ends.append(primary_end)
1093                    self._processes.append(process)
1094                self.is_parallelized = True
1095            except Exception:
1096                _log.info(
1097                    "Population parallelization failed. "
1098                    "Falling back to sequential stepping of chains."
1099                )
                _log.debug("Error was: ", exc_info=True)
1101        else:
1102            _log.info(
1103                "Chains are not parallelized. You can enable this by passing "
1104                "`pm.sample(cores=n)`, where n > 1."
1105            )
        super().__init__()
1107
1108    def __enter__(self):
1109        """Do nothing: processes are already started in ``__init__``."""
1110        return
1111
1112    def __exit__(self, exc_type, exc_val, exc_tb):
1113        if len(self._processes) > 0:
1114            try:
1115                for primary_end in self._primary_ends:
1116                    primary_end.send(None)
1117                for process in self._processes:
1118                    process.join(timeout=3)
1119            except Exception:
1120                _log.warning("Termination failed.")
1121        return
1122
1123    @staticmethod
1124    def _run_secondary(c, stepper_dumps, secondary_end):
        """This method is started in a separate process to perform stepping of a chain.

        Parameters
        ----------
        c : int
            number of this chain
        stepper_dumps : bytes
            the pickled step method (e.g., a ``CompoundStep``) to be stepped
        secondary_end : multiprocessing.connection.PipeConnection
            This is our connection to the main process
1135        """
1136        # re-seed each child process to make them unique
1137        np.random.seed(None)
1138        try:
1139            stepper = pickle.loads(stepper_dumps)
            # the stepper is not necessarily a PopulationArrayStepShared itself,
1141            # but rather a CompoundStep. PopulationArrayStepShared.population
1142            # has to be updated, therefore we identify the substeppers first.
1143            population_steppers = []
1144            for sm in stepper.methods if isinstance(stepper, CompoundStep) else [stepper]:
1145                if isinstance(sm, PopulationArrayStepShared):
1146                    population_steppers.append(sm)
1147            while True:
1148                incoming = secondary_end.recv()
1149                # receiving a None is the signal to exit
1150                if incoming is None:
1151                    break
1152                tune_stop, population = incoming
1153                if tune_stop:
1154                    stop_tuning(stepper)
1155                # forward the population to the PopulationArrayStepShared objects
1156                # This is necessary because due to the process fork, the population
1157                # object is no longer shared between the steppers.
1158                for popstep in population_steppers:
1159                    popstep.population = population
1160                update = stepper.step(population[c])
1161                secondary_end.send(update)
1162        except Exception:
1163            _log.exception(f"ChainWalker{c}")
1164        return
1165
1166    def step(self, tune_stop, population):
1167        """Step the entire population of chains.
1168
1169        Parameters
1170        ----------
1171        tune_stop : bool
1172            Indicates if the condition (i == tune) is fulfilled
1173        population : list
1174            Current Points of all chains
1175
1176        Returns
1177        -------
1178        update : list
1179            List of (Point, stats) tuples for all chains
1180        """
1181        updates = [None] * self.nchains
1182        if self.is_parallelized:
1183            for c in range(self.nchains):
1184                self._primary_ends[c].send((tune_stop, population))
1185            # Blockingly get the step outcomes
1186            for c in range(self.nchains):
1187                updates[c] = self._primary_ends[c].recv()
1188        else:
1189            for c in range(self.nchains):
1190                if tune_stop:
1191                    self._steppers[c] = stop_tuning(self._steppers[c])
1192                updates[c] = self._steppers[c].step(population[c])
1193        return updates
1194
1195
1196def _prepare_iter_population(
1197    draws: int,
1198    chains: list,
1199    step,
1200    start: list,
1201    parallelize: bool,
1202    tune=None,
1203    model=None,
1204    random_seed=None,
1205    progressbar=True,
1206):
1207    """Prepare a PopulationStepper and traces for population sampling.
1208
1209    Parameters
1210    ----------
1211    draws : int
1212        The number of samples to draw
1213    chains : list
1214        The chain numbers in the population
1215    step : function
1216        Step function (should be or contain a population step method)
1217    start : list
1218        Start points for each chain
1219    parallelize : bool
1220        Setting for multiprocess parallelization
1221    tune : int, optional
1222        Number of iterations to tune, if applicable (defaults to None)
1223    model : Model (optional if in ``with`` context)
1224    random_seed : int or list of ints, optional
        A list is accepted if ``cores`` is greater than one.
1226    progressbar : bool
1227        ``progressbar`` argument for the ``PopulationStepper``, (defaults to True)
1228
1229    Returns
1230    -------
1231    _iter_population : generator
1232        Yields traces of all chains at the same time
1233    """
1234    # chains contains the chain numbers, but for indexing we need indices...
1235    nchains = len(chains)
1236    model = modelcontext(model)
1237    draws = int(draws)
1238    if random_seed is not None:
1239        np.random.seed(random_seed)
1240    if draws < 1:
        raise ValueError("Argument `draws` must be greater than 0.")
1242
1243    # The initialization of traces, samplers and points must happen in the right order:
1244    # 1. traces are initialized and update_start_vals configures variable transforms
1245    # 2. population of points is created
1246    # 3. steppers are initialized and linked to the points object
1247    # 4. traces are configured to track the sampler stats
1248    # 5. a PopulationStepper is configured for parallelized stepping
1249
1250    # 1. prepare a BaseTrace for each chain
1251    traces = [_choose_backend(None, chain, model=model) for chain in chains]
1252    for c, strace in enumerate(traces):
1253        # initialize the trace size and variable transforms
1254        if len(strace) > 0:
1255            update_start_vals(start[c], strace.point(-1), model)
1256        else:
1257            update_start_vals(start[c], model.test_point, model)
1258
1259    # 2. create a population (points) that tracks each chain
1260    # it is updated as the chains are advanced
1261    population = [Point(start[c], model=model) for c in range(nchains)]
1262
1263    # 3. Set up the steppers
1264    steppers: List[Step] = []
1265    for c in range(nchains):
        # need independent samplers for each chain
1267        # it is important to copy the actual steppers (but not the delta_logp)
1268        if isinstance(step, CompoundStep):
1269            chainstep = CompoundStep([copy(m) for m in step.methods])
1270        else:
1271            chainstep = copy(step)
1272        # link population samplers to the shared population state
1273        for sm in chainstep.methods if isinstance(step, CompoundStep) else [chainstep]:
1274            if isinstance(sm, PopulationArrayStepShared):
1275                sm.link_population(population, c)
1276        steppers.append(chainstep)
1277
1278    # 4. configure tracking of sampler stats
1279    for c in range(nchains):
1280        if steppers[c].generates_stats and traces[c].supports_sampler_stats:
1281            traces[c].setup(draws, c, steppers[c].stats_dtypes)
1282        else:
1283            traces[c].setup(draws, c)
1284
1285    # 5. configure the PopulationStepper (expensive call)
1286    popstep = PopulationStepper(steppers, parallelize, progressbar=progressbar)
1287
1288    # Because the preparations above are expensive, the actual iterator is
1289    # in another method. This way the progbar will not be disturbed.
1290    return _iter_population(draws, tune, popstep, steppers, traces, population)
1291
1292
1293def _iter_population(draws, tune, popstep, steppers, traces, points):
1294    """Iterate a ``PopulationStepper``.
1295
1296    Parameters
1297    ----------
1298    draws : int
1299        number of draws per chain
1300    tune : int
1301        number of tuning steps
1302    popstep : PopulationStepper
1303        the helper object for (parallelized) stepping of chains
1304    steppers : list
1305        The step methods for each chain
1306    traces : list
1307        Traces for each chain
1308    points : list
1309        population of chain states
1310
1311    Yields
1312    ------
1313    traces : list
1314        List of trace objects of the individual chains
1315    """
1316    try:
1317        with popstep:
1318            # iterate draws of all chains
1319            for i in range(draws):
1320                # this call steps all chains and returns a list of (point, stats)
1321                # the `popstep` may interact with subprocesses internally
1322                updates = popstep.step(i == tune, points)
1323
1324                # apply the update to the points and record to the traces
1325                for c, strace in enumerate(traces):
1326                    if steppers[c].generates_stats:
1327                        points[c], stats = updates[c]
1328                        if strace.supports_sampler_stats:
1329                            strace.record(points[c], stats)
1330                        else:
1331                            strace.record(points[c])
1332                    else:
1333                        points[c] = updates[c]
1334                        strace.record(points[c])
1335                # yield the state of all chains in parallel
1336                yield traces
1337    except KeyboardInterrupt:
1338        for c, strace in enumerate(traces):
1339            strace.close()
1340            if hasattr(steppers[c], "report"):
1341                steppers[c].report._finalize(strace)
1342        raise
1343    except BaseException:
1344        for c, strace in enumerate(traces):
1345            strace.close()
1346        raise
1347    else:
1348        for c, strace in enumerate(traces):
1349            strace.close()
1350            if hasattr(steppers[c], "report"):
1351                steppers[c].report._finalize(strace)
1352
1353
1354def _choose_backend(trace, chain, **kwds) -> Backend:
    """Selects or creates an NDArray trace backend for a particular chain.
1356
1357    Parameters
1358    ----------
1359    trace : BaseTrace, list, MultiTrace, or None
1360        This should be a BaseTrace, list of variables to track,
1361        or a MultiTrace object with past values.
1362        If a MultiTrace object is given, it must contain samples for the chain number ``chain``.
1363        If None or a list of variables, the NDArray backend is used.
1364    chain : int
1365        Number of the chain of interest.
1366    **kwds :
1367        keyword arguments to forward to the backend creation
1368
1369    Returns
1370    -------
1371    trace : BaseTrace
1372        A trace object for the selected chain
1373    """
1374    if isinstance(trace, BaseTrace):
1375        return trace
1376    if isinstance(trace, MultiTrace):
1377        return trace._straces[chain]
1378    if trace is None:
1379        return NDArray(**kwds)
1380
1381    return NDArray(vars=trace, **kwds)
1382
1383
1384def _mp_sample(
1385    draws: int,
1386    tune: int,
1387    step,
1388    chains: int,
1389    cores: int,
1390    chain: int,
1391    random_seed: list,
1392    start: list,
1393    progressbar=True,
1394    trace=None,
1395    model=None,
1396    callback=None,
1397    discard_tuned_samples=True,
1398    mp_ctx=None,
1399    pickle_backend="pickle",
1400    **kwargs,
1401):
1402    """Main iteration for multiprocess sampling.
1403
1404    Parameters
1405    ----------
1406    draws : int
1407        The number of samples to draw
1408    tune : int, optional
1409        Number of iterations to tune, if applicable (defaults to None)
1410    step : function
1411        Step function
1412    chains : int
1413        The number of chains to sample.
1414    cores : int
1415        The number of chains to run in parallel.
1416    chain : int
1417        Number of the first chain.
1418    random_seed : list of ints
1419        Random seeds for each chain.
1420    start : list
1421        Starting points for each chain.
1422    progressbar : bool
1423        Whether or not to display a progress bar in the command line.
1424    trace : BaseTrace, list, MultiTrace or None
1425        This should be a backend instance, a list of variables to track, or a MultiTrace object
1426        with past values. If a MultiTrace object is given, it must contain samples for the chain
1427        number ``chain``. If None or a list of variables, the NDArray backend is used.
1428    model : Model (optional if in ``with`` context)
    callback : Callable
        A function called with every sample from the trace of a chain. It is passed the trace and
        the current draw; the trace contains all samples collected so far for that single chain.
        The ``draw.chain`` attribute can be used to determine which of the active chains the
        sample is drawn from.
        Sampling can be interrupted by raising a ``KeyboardInterrupt`` inside the callback.
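        For example, a callback that stops sampling once a chain has collected 100
        samples (a minimal sketch; the name and threshold are arbitrary)::

            def my_callback(trace, draw):
                if len(trace) >= 100:
                    raise KeyboardInterrupt()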
1435
1436    Returns
1437    -------
1438    trace : pymc3.backends.base.MultiTrace
1439        A ``MultiTrace`` object that contains the samples for all chains.
1440    """
1441    import pymc3.parallel_sampling as ps
1442
1443    # We did draws += tune in pm.sample
1444    draws -= tune
1445
1446    traces = []
1447    for idx in range(chain, chain + chains):
1448        if trace is not None:
1449            strace = _choose_backend(copy(trace), idx, model=model)
1450        else:
1451            strace = _choose_backend(None, idx, model=model)
        # For user-supplied start values, fill in any missing values from the
        # model's test point if the dict does not contain all parameters.
1454        update_start_vals(start[idx - chain], model.test_point, model)
        if step.generates_stats and strace.supports_sampler_stats:
            # ``idx`` already includes the ``chain`` offset from the loop above.
            strace.setup(draws + tune, idx, step.stats_dtypes)
        else:
            strace.setup(draws + tune, idx)
1459        traces.append(strace)
1460
1461    sampler = ps.ParallelSampler(
1462        draws,
1463        tune,
1464        chains,
1465        cores,
1466        random_seed,
1467        start,
1468        step,
1469        chain,
1470        progressbar,
1471        mp_ctx=mp_ctx,
1472        pickle_backend=pickle_backend,
1473    )
1474    try:
1475        try:
1476            with sampler:
1477                for draw in sampler:
1478                    trace = traces[draw.chain - chain]
1479                    if trace.supports_sampler_stats and draw.stats is not None:
1480                        trace.record(draw.point, draw.stats)
1481                    else:
1482                        trace.record(draw.point)
1483                    if draw.is_last:
1484                        trace.close()
1485                        if draw.warnings is not None:
1486                            trace._add_warnings(draw.warnings)
1487
1488                    if callback is not None:
1489                        callback(trace=trace, draw=draw)
1490
1491        except ps.ParallelSamplingError as error:
1492            trace = traces[error._chain - chain]
1493            trace._add_warnings(error._warnings)
1494            for trace in traces:
1495                trace.close()
1496
1497            multitrace = MultiTrace(traces)
1498            multitrace._report._log_summary()
1499            raise
1500        return MultiTrace(traces)
1501    except KeyboardInterrupt:
1502        if discard_tuned_samples:
1503            traces, length = _choose_chains(traces, tune)
1504        else:
1505            traces, length = _choose_chains(traces, 0)
1506        return MultiTrace(traces)[:length]
1507    finally:
1508        for trace in traces:
1509            trace.close()
1510
1511
1512def _choose_chains(traces, tune):
1513    """
1514    Filter and slice traces such that (n_traces * len(shortest_trace)) is maximized.
1515
    We get here after a ``KeyboardInterrupt``, so the traces may have
    different lengths. We therefore pick the subset of traces (dropping the
    shortest ones first) such that (number of traces) * (length of the
    shortest kept trace) is maximized.
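
    For example, with traces of lengths 400, 500, and 600 (and ``tune=0``),
    keeping all three yields 3 * 400 = 1200 usable draws, whereas keeping only
    the two longest yields 2 * 500 = 1000, so all three traces are kept and
    sliced to 400 draws each.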
1520    """
1521    if tune is None:
1522        tune = 0
1523
    if not traces:
        return [], 0
1526
1527    lengths = [max(0, len(trace) - tune) for trace in traces]
1528    if not sum(lengths):
1529        raise ValueError("Not enough samples to build a trace.")
1530
1531    idxs = np.argsort(lengths)
1532    l_sort = np.array(lengths)[idxs]
1533
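    # At sorted position i we could keep the (n - i) longest traces, each
    # truncated to l_sort[i] draws; pick the position that maximizes the
    # total number of retained draws.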
1534    use_until = np.argmax(l_sort * np.arange(1, l_sort.shape[0] + 1)[::-1])
1535    final_length = l_sort[use_until]
1536
1537    return [traces[idx] for idx in idxs[use_until:]], final_length + tune
1538
1539
1540def stop_tuning(step):
1541    """Stop tuning the current step method."""
1542    step.stop_tuning()
1543    return step
1544
1545
1546class _DefaultTrace:
1547    """
1548    Utility for collecting samples into a dictionary.
1549
1550    Name comes from its similarity to ``defaultdict``:
1551    entries are lazily created.
1552
1553    Parameters
1554    ----------
1555    samples : int
1556        The number of samples that will be collected, per variable,
1557        into the trace.
1558
1559    Attributes
1560    ----------
1561    trace_dict : Dict[str, np.ndarray]
        A dictionary constituting a trace. Should be extracted
        after a procedure has filled the `_DefaultTrace` using the
        `insert()` method.
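
    Examples
    --------
    A minimal sketch of the intended usage::

        trace = _DefaultTrace(samples=2)
        trace.insert("x", np.zeros(3), 0)
        trace.insert("x", np.ones(3), 1)
        trace.trace_dict["x"]  # array of shape (2, 3)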
1565    """
1566
    trace_dict: Dict[str, np.ndarray]
1568    _len: Optional[int] = None
1569
1570    def __init__(self, samples: int):
1571        self._len = samples
1572        self.trace_dict = {}
1573
1574    def insert(self, k: str, v, idx: int):
1575        """
1576        Insert `v` as the value of the `idx`th sample for the variable `k`.
1577
1578        Parameters
1579        ----------
1580        k: str
1581            Name of the variable.
        v: anything that can go into a numpy array (including a numpy array)
            The value of the `idx`th sample from variable `k`.
        idx: int
            The index of the sample we are inserting into the trace.
1586        """
1587        value_shape = np.shape(v)
1588
1589        # initialize if necessary
1590        if k not in self.trace_dict:
1591            array_shape = (self._len,) + value_shape
1592            self.trace_dict[k] = np.empty(array_shape, dtype=np.array(v).dtype)
1593
1594        # do the actual insertion
1595        if value_shape == ():
1596            self.trace_dict[k][idx] = v
1597        else:
1598            self.trace_dict[k][idx, :] = v
1599
1600
1601def sample_posterior_predictive(
1602    trace,
1603    samples: Optional[int] = None,
1604    model: Optional[Model] = None,
1605    var_names: Optional[List[str]] = None,
1606    size: Optional[int] = None,
1607    keep_size: Optional[bool] = False,
1608    random_seed=None,
1609    progressbar: bool = True,
1610) -> Dict[str, np.ndarray]:
1611    """Generate posterior predictive samples from a model given a trace.
1612
1613    Parameters
1614    ----------
    trace : backend, list, xarray.Dataset, arviz.InferenceData, or MultiTrace
        Trace generated from MCMC sampling, a list of dicts (e.g. points or the result of
        ``find_MAP()``), or an ``xarray.Dataset`` (e.g. ``InferenceData.posterior`` or
        ``InferenceData.prior``).
1618    samples : int
1619        Number of posterior predictive samples to generate. Defaults to one posterior predictive
1620        sample per posterior sample, that is, the number of draws times the number of chains. It
1621        is not recommended to modify this value; when modified, some chains may not be represented
1622        in the posterior predictive sample.
1623    model : Model (optional if in ``with`` context)
1624        Model used to generate ``trace``
1628    var_names : Iterable[str]
1629        Names of variables for which to compute the posterior predictive samples.
1630    size : int
1631        The number of random draws from the distribution specified by the parameters in each
1632        sample of the trace. Not recommended unless more than ndraws times nchains posterior
1633        predictive samples are needed.
1634    keep_size : bool, optional
1635        Force posterior predictive sample to have the same shape as posterior and sample stats
1636        data: ``(nchains, ndraws, ...)``. Overrides samples and size parameters.
1637    random_seed : int
1638        Seed for the random number generator.
1639    progressbar : bool
1640        Whether or not to display a progress bar in the command line. The bar shows the percentage
1641        of completion, the sampling speed in samples per second (SPS), and the estimated remaining
1642        time until completion ("expected time of arrival"; ETA).
1643
1644    Returns
1645    -------
    samples : dict
        Dictionary with the variable names as keys; the values are numpy arrays containing
        the posterior predictive samples.
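
    Examples
    --------
    A minimal sketch (the model, variable names, and data below are only
    illustrative, not part of this function's API)::

        import numpy as np
        import pymc3 as pm

        y_obs = np.random.randn(100)

        with pm.Model() as model:
            mu = pm.Normal("mu", mu=0.0, sigma=1.0)
            pm.Normal("y", mu=mu, sigma=1.0, observed=y_obs)
            trace = pm.sample()

        with model:
            ppc = pm.sample_posterior_predictive(trace)

        # ppc["y"] holds one predictive draw of the observed data per posterior sample.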
1649    """
1650
1651    _trace: Union[MultiTrace, PointList]
1652    if isinstance(trace, InferenceData):
1653        _trace = dataset_to_point_list(trace.posterior)
1654    elif isinstance(trace, xarray.Dataset):
1655        _trace = dataset_to_point_list(trace)
1656    else:
1657        _trace = trace
1658
1659    nchain: int
1660    len_trace: int
1661    if isinstance(trace, (InferenceData, xarray.Dataset)):
1662        nchain, len_trace = chains_and_samples(trace)
1663    else:
1664        len_trace = len(_trace)
1665        try:
1666            nchain = _trace.nchains
1667        except AttributeError:
1668            nchain = 1
1669
1670    if keep_size and samples is not None:
1671        raise IncorrectArgumentsError("Should not specify both keep_size and samples arguments")
1672    if keep_size and size is not None:
1673        raise IncorrectArgumentsError("Should not specify both keep_size and size arguments")
1674
1675    if samples is None:
1676        if isinstance(_trace, MultiTrace):
1677            samples = sum(len(v) for v in _trace._straces.values())
1678        elif isinstance(_trace, list) and all(isinstance(x, dict) for x in _trace):
1679            # this is a list of points
1680            samples = len(_trace)
1681        else:
1682            raise TypeError(
1683                "Do not know how to compute number of samples for trace argument of type %s"
1684                % type(_trace)
1685            )
1686
1687    assert samples is not None
1688    if samples < len_trace * nchain:
1689        warnings.warn(
1690            "samples parameter is smaller than nchains times ndraws, some draws "
1691            "and/or chains may not be represented in the returned posterior "
1692            "predictive sample"
1693        )
1694
1695    model = modelcontext(model)
1696
1697    if model.potentials:
1698        warnings.warn(
1699            "The effect of Potentials on other parameters is ignored during posterior predictive sampling. "
1700            "This is likely to lead to invalid or biased predictive samples.",
1701            UserWarning,
1702        )
1703
1704    if var_names is not None:
1705        vars_ = [model[x] for x in var_names]
1706    else:
1707        vars_ = model.observed_RVs
1708
1709    if random_seed is not None:
1710        np.random.seed(random_seed)
1711
1712    indices = np.arange(samples)
1713
1714    if progressbar:
1715        indices = progress_bar(indices, total=samples, display=progressbar)
1716
1717    ppc_trace_t = _DefaultTrace(samples)
1718    try:
1719        for idx in indices:
1720            if nchain > 1:
1721                # the trace object will either be a MultiTrace (and have _straces)...
1722                if hasattr(_trace, "_straces"):
1723                    chain_idx, point_idx = np.divmod(idx, len_trace)
1724                    param = cast(MultiTrace, _trace)._straces[chain_idx % nchain].point(point_idx)
1725                # ... or a PointList
1726                else:
1727                    param = cast(PointList, _trace)[idx % (len_trace * nchain)]
1728            # there's only a single chain, but the index might hit it multiple times if
1729            # the number of indices is greater than the length of the trace.
1730            else:
1731                param = _trace[idx % len_trace]
1732
1733            values = draw_values(vars_, point=param, size=size)
1734            for k, v in zip(vars_, values):
1735                ppc_trace_t.insert(k.name, v, idx)
1736    except KeyboardInterrupt:
1737        pass
1738
1739    ppc_trace = ppc_trace_t.trace_dict
1740    if keep_size:
1741        for k, ary in ppc_trace.items():
1742            ppc_trace[k] = ary.reshape((nchain, len_trace, *ary.shape[1:]))
1743
1744    return ppc_trace
1745
1746
1747def sample_posterior_predictive_w(
1748    traces,
1749    samples: Optional[int] = None,
1750    models: Optional[List[Model]] = None,
1751    weights: Optional[ArrayLike] = None,
1752    random_seed: Optional[int] = None,
1753    progressbar: bool = True,
1754):
1755    """Generate weighted posterior predictive samples from a list of models and
1756    a list of traces according to a set of weights.
1757
1758    Parameters
1759    ----------
1760    traces : list or list of lists
1761        List of traces generated from MCMC sampling (xarray.Dataset, arviz.InferenceData, or
        MultiTrace), or a list of lists containing dicts from find_MAP() or points. The number of
1763        traces should be equal to the number of weights.
1764    samples : int, optional
1765        Number of posterior predictive samples to generate. Defaults to the
        length of the shortest trace in ``traces``.
1767    models : list of Model
1768        List of models used to generate the list of traces. The number of models should be equal to
1769        the number of weights and the number of observed RVs should be the same for all models.
        By default a single model will be inferred from the ``with`` context; in this case, results
        will only be meaningful if all models share the same distributions for the observed RVs.
1772    weights : array-like, optional
        Individual weights for each trace. By default, each model gets the same weight.
1774    random_seed : int, optional
1775        Seed for the random number generator.
1776    progressbar : bool, optional default True
1777        Whether or not to display a progress bar in the command line. The bar shows the percentage
1778        of completion, the sampling speed in samples per second (SPS), and the estimated remaining
1779        time until completion ("expected time of arrival"; ETA).
1780
1781    Returns
1782    -------
1783    samples : dict
        Dictionary with the variable names as keys. The values are the
        posterior predictive samples from the weighted models.
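
    Examples
    --------
    A minimal sketch of averaging the predictions of two models (the models, the
    data ``y_obs``, and the equal weights below are only illustrative)::

        import numpy as np
        import pymc3 as pm

        y_obs = np.random.randn(50)

        with pm.Model() as model_0:
            mu = pm.Normal("mu", mu=0.0, sigma=1.0)
            pm.Normal("y", mu=mu, sigma=1.0, observed=y_obs)
            trace_0 = pm.sample()

        with pm.Model() as model_1:
            mu = pm.Normal("mu", mu=0.0, sigma=10.0)
            pm.Normal("y", mu=mu, sigma=1.0, observed=y_obs)
            trace_1 = pm.sample()

        ppc_w = pm.sample_posterior_predictive_w(
            traces=[trace_0, trace_1],
            models=[model_0, model_1],
            weights=[0.5, 0.5],
        )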
1786    """
1787    np.random.seed(random_seed)
1788
1789    if isinstance(traces[0], InferenceData):
1790        n_samples = [
1791            trace.posterior.sizes["chain"] * trace.posterior.sizes["draw"] for trace in traces
1792        ]
1793        traces = [dataset_to_point_list(trace.posterior) for trace in traces]
1794    elif isinstance(traces[0], xarray.Dataset):
1795        n_samples = [trace.sizes["chain"] * trace.sizes["draw"] for trace in traces]
1796        traces = [dataset_to_point_list(trace) for trace in traces]
1797    else:
1798        n_samples = [len(i) * i.nchains for i in traces]
1799
1800    if models is None:
1801        models = [modelcontext(models)] * len(traces)
1802
1803    for model in models:
1804        if model.potentials:
1805            warnings.warn(
1806                "The effect of Potentials on other parameters is ignored during posterior predictive sampling. "
1807                "This is likely to lead to invalid or biased predictive samples.",
1808                UserWarning,
1809            )
1810            break
1811
1812    if weights is None:
1813        weights = [1] * len(traces)
1814
1815    if len(traces) != len(weights):
1816        raise ValueError("The number of traces and weights should be the same")
1817
1818    if len(models) != len(weights):
1819        raise ValueError("The number of models and weights should be the same")
1820
1821    length_morv = len(models[0].observed_RVs)
1822    if any(len(i.observed_RVs) != length_morv for i in models):
1823        raise ValueError("The number of observed RVs should be the same for all models")
1824
1825    weights = np.asarray(weights)
1826    p = weights / np.sum(weights)
1827
1828    min_tr = min(n_samples)
1829
1830    n = (min_tr * p).astype("int")
    # ensure the counts in n sum to min_tr by adjusting the largest entry
1832    idx = np.argmax(n)
1833    n[idx] = n[idx] + min_tr - np.sum(n)
1834    trace = []
1835    for i, j in enumerate(n):
1836        tr = traces[i]
1837        len_trace = len(tr)
1838        try:
1839            nchain = tr.nchains
1840        except AttributeError:
1841            nchain = 1
1842
1843        indices = np.random.randint(0, nchain * len_trace, j)
1844        if nchain > 1:
1845            chain_idx, point_idx = np.divmod(indices, len_trace)
1846            for idx in zip(chain_idx, point_idx):
1847                trace.append(tr._straces[idx[0]].point(idx[1]))
1848        else:
1849            for idx in indices:
1850                trace.append(tr[idx])
1851
1852    obs = [x for m in models for x in m.observed_RVs]
1853    variables = np.repeat(obs, n)
1854
1855    lengths = list({np.atleast_1d(observed).shape for observed in obs})
1856
1857    if len(lengths) == 1:
1858        size = [None for i in variables]
1859    elif len(lengths) > 2:
1860        raise ValueError("Observed variables could not be broadcast together")
1861    else:
1862        size = []
1863        x = np.zeros(shape=lengths[0])
1864        y = np.zeros(shape=lengths[1])
1865        b = np.broadcast(x, y)
1866        for var in variables:
1867            shape = np.shape(np.atleast_1d(var.distribution.default()))
1868            if shape != b.shape:
1869                size.append(b.shape)
1870            else:
1871                size.append(None)
1872    len_trace = len(trace)
1873
1874    if samples is None:
1875        samples = len_trace
1876
1877    indices = np.random.randint(0, len_trace, samples)
1878
1879    if progressbar:
1880        indices = progress_bar(indices, total=samples, display=progressbar)
1881
1882    try:
1883        ppc = defaultdict(list)
1884        for idx in indices:
1885            param = trace[idx]
1886            var = variables[idx]
            # TODO: sample_posterior_predictive_w currently only works for models with
            # a single observed RV.
1889            ppc[var.name].append(draw_values([var], point=param, size=size[idx])[0])
1890
1891    except KeyboardInterrupt:
1892        pass
1893    else:
1894        return {k: np.asarray(v) for k, v in ppc.items()}
1895
1896
1897def sample_prior_predictive(
1898    samples=500,
1899    model: Optional[Model] = None,
1900    var_names: Optional[Iterable[str]] = None,
1901    random_seed=None,
1902) -> Dict[str, np.ndarray]:
1903    """Generate samples from the prior predictive distribution.
1904
1905    Parameters
1906    ----------
1907    samples : int
1908        Number of samples from the prior predictive to generate. Defaults to 500.
1909    model : Model (optional if in ``with`` context)
1910    var_names : Iterable[str]
        A list of names of variables for which to compute the prior predictive
1912        samples. Defaults to both observed and unobserved RVs.
1913    random_seed : int
1914        Seed for the random number generator.
1915
1916    Returns
1917    -------
1918    dict
1919        Dictionary with variable names as keys. The values are numpy arrays of prior
1920        samples.
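
    Examples
    --------
    A minimal sketch (the model and variable names below are only illustrative)::

        import numpy as np
        import pymc3 as pm

        with pm.Model():
            mu = pm.Normal("mu", mu=0.0, sigma=1.0)
            pm.Normal("y", mu=mu, sigma=1.0, observed=np.zeros(10))
            prior = pm.sample_prior_predictive(samples=200)

        # ``prior`` maps variable names (e.g. "mu", "y") to arrays of 200 prior draws each.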
1921    """
1922    model = modelcontext(model)
1923
1924    if model.potentials:
1925        warnings.warn(
1926            "The effect of Potentials on other parameters is ignored during prior predictive sampling. "
1927            "This is likely to lead to invalid or biased predictive samples.",
1928            UserWarning,
1929        )
1930
1931    if var_names is None:
1932        prior_pred_vars = model.observed_RVs
1933        prior_vars = (
1934            get_default_varnames(model.unobserved_RVs, include_transformed=True) + model.potentials
1935        )
1936        vars_: Set[str] = {var.name for var in prior_vars + prior_pred_vars}
1937    else:
1938        vars_ = set(var_names)
1939
1940    if random_seed is not None:
1941        np.random.seed(random_seed)
1942    names = get_default_varnames(vars_, include_transformed=False)
1943    # draw_values fails with auto-transformed variables. transform them later!
1944    values = draw_values([model[name] for name in names], size=samples)
1945
    data = dict(zip(names, values))
    if not data:
        raise AssertionError("No variables sampled: attempting to sample %s" % names)
1949
1950    prior: Dict[str, np.ndarray] = {}
1951    for var_name in vars_:
1952        if var_name in data:
1953            prior[var_name] = data[var_name]
1954        elif is_transformed_name(var_name):
1955            untransformed = get_untransformed_name(var_name)
1956            if untransformed in data:
1957                prior[var_name] = model[untransformed].transformation.forward_val(
1958                    data[untransformed]
1959                )
1960    return prior
1961
1962
1963def _init_jitter(model, chains, jitter_max_retries):
1964    """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.
1965
    ``pymc3.util.check_start_vals`` is used to test whether the jittered starting values
    produce a finite log probability. Invalid values are resampled, up to `jitter_max_retries`
    times, after which the last sampled values are returned even if they are invalid.
1969
1970    Parameters
1971    ----------
1972    model : pymc3.Model
1973    chains : int
1974    jitter_max_retries : int
1975        Maximum number of repeated attempts at initializing values (per chain).
1976
1977    Returns
1978    -------
    start : list of ``pymc3.model.Point``
        Jittered starting points, one per chain.
1981    """
1982    start = []
1983    for _ in range(chains):
1984        for i in range(jitter_max_retries + 1):
1985            mean = {var: val.copy() for var, val in model.test_point.items()}
1986            for val in mean.values():
1987                val[...] += 2 * np.random.rand(*val.shape) - 1
1988
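            # Only validate the jittered point while retries remain; on the final
            # attempt the point is kept even if its log probability is not finite.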
1989            if i < jitter_max_retries:
1990                try:
1991                    check_start_vals(mean, model)
1992                except SamplingError:
1993                    pass
1994                else:
1995                    break
1996
1997        start.append(mean)
1998    return start
1999
2000
2001def init_nuts(
2002    init="auto",
2003    chains=1,
2004    n_init=500000,
2005    model=None,
2006    random_seed=None,
2007    progressbar=True,
2008    jitter_max_retries=10,
2009    **kwargs,
2010):
2011    """Set up the mass matrix initialization for NUTS.
2012
    NUTS convergence and sampling speed are extremely dependent on the
2014    choice of mass/scaling matrix. This function implements different
2015    methods for choosing or adapting the mass matrix.
2016
2017    Parameters
2018    ----------
2019    init : str
2020        Initialization method to use.
2021
2022        * auto: Choose a default initialization method automatically.
2023          Currently, this is ``jitter+adapt_diag``, but this can change in the future. If you
2024          depend on the exact behaviour, choose an initialization method explicitly.
        * adapt_diag: Start with an identity mass matrix and then adapt a diagonal based on the
2026          variance of the tuning samples. All chains use the test value (usually the prior mean)
2027          as starting point.
2028        * jitter+adapt_diag: Same as ``adapt_diag``, but use test value plus a uniform jitter in
2029          [-1, 1] as starting point in each chain.
2030        * advi+adapt_diag: Run ADVI and then adapt the resulting diagonal mass matrix based on the
2031          sample variance of the tuning samples.
2032        * advi+adapt_diag_grad: Run ADVI and then adapt the resulting diagonal mass matrix based
2033          on the variance of the gradients during tuning. This is **experimental** and might be
2034          removed in a future release.
2035        * advi: Run ADVI to estimate posterior mean and diagonal mass matrix.
2036        * advi_map: Initialize ADVI with MAP and use MAP as starting point.
2037        * map: Use the MAP as starting point. This is discouraged.
2038        * adapt_full: Adapt a dense mass matrix using the sample covariances. All chains use the
2039          test value (usually the prior mean) as starting point.
2040        * jitter+adapt_full: Same as ``adapt_full``, but use test value plus a uniform jitter in
2041          [-1, 1] as starting point in each chain.
2042
2043    chains : int
        Number of chains to initialize.
2045    n_init : int
        Number of iterations of the initializer. Only relevant for ADVI-based init methods.
    model : Model (optional if in ``with`` context)
    random_seed : int or array-like of int, optional
        Seed for the random number generator; if several seeds are given, only the first is used.
    progressbar : bool
        Whether or not to display a progress bar for ADVI sampling.
2050    jitter_max_retries : int
        Maximum number of repeated attempts (per chain) at creating an initial point with uniform
        jitter that yields a finite log probability. This applies to the ``jitter+adapt_diag`` and
        ``jitter+adapt_full`` init methods.
2054    **kwargs : keyword arguments
2055        Extra keyword arguments are forwarded to pymc3.NUTS.
2056
2057    Returns
2058    -------
2059    start : ``pymc3.model.Point``
2060        Starting point for sampler
2061    nuts_sampler : ``pymc3.step_methods.NUTS``
2062        Instantiated and initialized NUTS sampler object
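
    Examples
    --------
    ``init_nuts`` is normally called for you by ``pm.sample``, but it can also be used
    directly (a minimal sketch; the model below is only illustrative)::

        import pymc3 as pm

        with pm.Model():
            x = pm.Normal("x", mu=0.0, sigma=1.0, shape=3)
            start, step = pm.init_nuts(init="adapt_diag", chains=2)
            trace = pm.sample(draws=500, tune=500, step=step, start=start, chains=2)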
2063    """
2064    model = modelcontext(model)
2065
2066    vars = kwargs.get("vars", model.vars)
2067    if set(vars) != set(model.vars):
2068        raise ValueError("Must use init_nuts on all variables of a model.")
2069    if not all_continuous(vars):
        raise ValueError("init_nuts can only be used for models with only continuous variables.")
2071
    if not isinstance(init, str):
        raise TypeError("init must be a string.")

    init = init.lower()
2077
2078    if init == "auto":
2079        init = "jitter+adapt_diag"
2080
2081    _log.info(f"Initializing NUTS using {init}...")
2082
2083    if random_seed is not None:
2084        random_seed = int(np.atleast_1d(random_seed)[0])
2085        np.random.seed(random_seed)
2086
2087    cb = [
2088        pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"),
2089        pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
2090    ]
2091
2092    if init == "adapt_diag":
2093        start = [model.test_point] * chains
2094        mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0)
2095        var = np.ones_like(mean)
2096        potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, var, 10)
2097    elif init == "jitter+adapt_diag":
2098        start = _init_jitter(model, chains, jitter_max_retries)
2099        mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0)
2100        var = np.ones_like(mean)
2101        potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, var, 10)
2102    elif init == "advi+adapt_diag_grad":
2103        approx: pm.MeanField = pm.fit(
2104            random_seed=random_seed,
2105            n=n_init,
2106            method="advi",
2107            model=model,
2108            callbacks=cb,
2109            progressbar=progressbar,
2110            obj_optimizer=pm.adagrad_window,
2111        )
2112        start = approx.sample(draws=chains)
2113        start = list(start)
2114        stds = approx.bij.rmap(approx.std.eval())
2115        cov = model.dict_to_array(stds) ** 2
2116        mean = approx.bij.rmap(approx.mean.get_value())
2117        mean = model.dict_to_array(mean)
2118        weight = 50
2119        potential = quadpotential.QuadPotentialDiagAdaptGrad(model.ndim, mean, cov, weight)
2120    elif init == "advi+adapt_diag":
2121        approx = pm.fit(
2122            random_seed=random_seed,
2123            n=n_init,
2124            method="advi",
2125            model=model,
2126            callbacks=cb,
2127            progressbar=progressbar,
2128            obj_optimizer=pm.adagrad_window,
2129        )
2130        start = approx.sample(draws=chains)
2131        start = list(start)
2132        stds = approx.bij.rmap(approx.std.eval())
2133        cov = model.dict_to_array(stds) ** 2
2134        mean = approx.bij.rmap(approx.mean.get_value())
2135        mean = model.dict_to_array(mean)
2136        weight = 50
2137        potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, cov, weight)
2138    elif init == "advi":
2139        approx = pm.fit(
2140            random_seed=random_seed,
2141            n=n_init,
2142            method="advi",
2143            model=model,
2144            callbacks=cb,
2145            progressbar=progressbar,
2146            obj_optimizer=pm.adagrad_window,
2147        )
2148        start = approx.sample(draws=chains)
2149        start = list(start)
2150        stds = approx.bij.rmap(approx.std.eval())
2151        cov = model.dict_to_array(stds) ** 2
2152        potential = quadpotential.QuadPotentialDiag(cov)
2153    elif init == "advi_map":
2154        start = pm.find_MAP(include_transformed=True)
2155        approx = pm.MeanField(model=model, start=start)
2156        pm.fit(
2157            random_seed=random_seed,
2158            n=n_init,
2159            method=pm.KLqp(approx),
2160            callbacks=cb,
2161            progressbar=progressbar,
2162            obj_optimizer=pm.adagrad_window,
2163        )
2164        start = approx.sample(draws=chains)
2165        start = list(start)
2166        stds = approx.bij.rmap(approx.std.eval())
2167        cov = model.dict_to_array(stds) ** 2
2168        potential = quadpotential.QuadPotentialDiag(cov)
2169    elif init == "map":
2170        start = pm.find_MAP(include_transformed=True)
2171        cov = pm.find_hessian(point=start)
2172        start = [start] * chains
2173        potential = quadpotential.QuadPotentialFull(cov)
2174    elif init == "adapt_full":
2175        start = [model.test_point] * chains
2176        mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0)
2177        cov = np.eye(model.ndim)
2178        potential = quadpotential.QuadPotentialFullAdapt(model.ndim, mean, cov, 10)
2179    elif init == "jitter+adapt_full":
2180        start = _init_jitter(model, chains, jitter_max_retries)
2181        mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0)
2182        cov = np.eye(model.ndim)
2183        potential = quadpotential.QuadPotentialFullAdapt(model.ndim, mean, cov, 10)
2184    else:
2185        raise ValueError(f"Unknown initializer: {init}.")
2186
2187    step = pm.NUTS(potential=potential, model=model, **kwargs)
2188
2189    return start, step
2190