pymc.sampling.jax.sample_blackjax_nuts
- pymc.sampling.jax.sample_blackjax_nuts(draws=1000, *, tune=1000, chains=4, target_accept=0.8, random_seed=None, initvals=None, jitter=True, model=None, var_names=None, nuts_kwargs=None, progressbar=True, keep_untransformed=False, chain_method='parallel', postprocessing_backend=None, postprocessing_vectorize=None, postprocessing_chunks=None, idata_kwargs=None, compute_convergence_checks=True, nuts_sampler='blackjax')
- Draw samples from the posterior using a JAX NUTS method.
- Parameters:
- draws: int, default 1000
- The number of samples to draw. Tuning samples are discarded by default.
- tune: int, default 1000
- Number of iterations to tune. Samplers adjust the step sizes, scalings or similar during tuning. Tuning samples are drawn in addition to the number specified in the draws argument. Tuned samples are discarded.
- chains: int, default 4
- The number of chains to sample. 
- target_accept: float in [0, 1], default 0.8
- The step size is tuned such that we approximate this acceptance rate. Higher values like 0.9 or 0.95 often work better for problematic posteriors. 
- random_seed: int, RandomState or Generator, optional
- Random seed used by the sampling steps. 
- initvals: StartDict or Sequence[Optional[StartDict]], optional
- Initial values for random variables provided as a dictionary (or sequence of dictionaries) mapping the random variable (by name or reference) to desired starting values. 
- jitter: bool, default True
- If True, add jitter to initial points. 
- model: Model, optional
- Model to sample from. The model needs to have free random variables. When called inside the model's with block, it defaults to that model; otherwise the model must be passed explicitly.
- var_names: sequence of str, optional
- Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior. 
- nuts_kwargs: dict, optional
- Keyword arguments for the underlying NUTS sampler.
- progressbar: bool, default True
- If True, display a progress bar while sampling.
- keep_untransformed: bool, default False
- Include untransformed variables in the posterior samples. 
- chain_method: str, default “parallel”
- Specify how the chains are sampled. The choices include “parallel” and “vectorized”.
- postprocessing_backend: Optional[Literal[“cpu”, “gpu”]], default None
- Backend on which postprocessing is computed: “cpu” or “gpu”.
- postprocessing_vectorize: Literal[“vmap”, “scan”], default “scan”
- How to vectorize the postprocessing: “vmap” or a sequential “scan”.
- postprocessing_chunks: None
- This argument is deprecated.
- idata_kwargs: dict, optional
- Keyword arguments for arviz.from_dict(). It also accepts a boolean value for the log_likelihood key, indicating whether the pointwise log likelihood should be included in the returned object. Values for observed_data, constant_data, coords, and dims are inferred from the model argument if not provided in idata_kwargs. If coords and dims are provided, they are used to update the inferred dictionaries.
- compute_convergence_checks: bool, default True
- If True, compute ESS and R-hat values and warn if they indicate potential sampling issues.
- nuts_sampler: Literal[“numpyro”, “blackjax”], default “blackjax”
- NUTS sampler library to use. Do not change it; call sample_numpyro_nuts or sample_blackjax_nuts as appropriate instead.
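The R-hat check enabled by compute_convergence_checks compares the variance between chains with the variance within chains. To illustrate the idea, here is a minimal NumPy sketch of the classic split R-hat, assuming the draws of one scalar variable are arranged as (chain, draw). This is not PyMC's or ArviZ's implementation (ArviZ's az.rhat defaults to a rank-normalized variant), and split_rhat is a hypothetical helper name.

```python
import numpy as np


def split_rhat(draws: np.ndarray) -> float:
    """Classic split R-hat for an array of shape (chains, draws).

    Illustrative sketch only; not the rank-normalized variant used by ArviZ.
    """
    n = draws.shape[1]
    half = n // 2
    # Split every chain into two halves so that drift within a chain
    # also inflates R-hat (the last draw is dropped when n is odd).
    split = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n_split = split.shape[1]
    chain_means = split.mean(axis=1)
    # W: average of the within-chain variances.
    within = split.var(axis=1, ddof=1).mean()
    # B: between-chain variance, scaled by the number of draws per chain.
    between = n_split * chain_means.var(ddof=1)
    # Pooled estimate of the marginal posterior variance.
    var_hat = (n_split - 1) / n_split * within + between / n_split
    return float(np.sqrt(var_hat / within))


rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))  # four chains sampling the same target
stuck = mixed + np.array([[0.0], [0.0], [0.0], [3.0]])  # last chain shifted by 3

print(round(split_rhat(mixed), 3))  # close to 1.0
print(round(split_rhat(stuck), 3))  # well above 1.01
```

Chains that agree give a value close to 1, while a single chain stuck in a different region pushes R-hat well above 1; values above roughly 1.01 are commonly treated as a warning sign.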
 
- Returns:
- InferenceData
ArviZ InferenceData object that contains the posterior samples, together with their respective sample stats and pointwise log likelihood values (unless skipped with idata_kwargs).