import arviz as az
import bambi as bmb
import numpy as np
import pandas as pd
Alternative sampling backends
In Bambi, the sampler used is automatically selected given the type of variables used in the model. For inference, Bambi supports both MCMC and variational inference. By default, Bambi uses PyMC's implementation of the adaptive Hamiltonian Monte Carlo (HMC) algorithm, also known as the No-U-Turn Sampler (NUTS). This sampler is a good choice for many models. However, it is not the only sampling method, nor is PyMC the only library implementing NUTS.
To this end, Bambi supports multiple backends for MCMC sampling, such as NumPyro and BlackJAX. This notebook covers how to use these alternatives in Bambi.
Note: Bambi utilizes bayeux to access a variety of sampling backends. Thus, you will need to install the optional dependencies declared in the Bambi pyproject.toml file to use these backends.
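For example, the JAX-based backends can be installed from a notebook cell. The extra name below is an assumption based on the optional dependency groups in pyproject.toml, so verify it against the file.

%pip install "bambi[jax]"  # hypothetical extra name; check Bambi's pyproject.toml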
Bayeux
Bambi leverages bayeux to access different sampling backends. In short, bayeux lets you write a probabilistic model in JAX and immediately have access to state-of-the-art inference methods. Since the underlying Bambi model is a PyMC model, this PyMC model can be "given" to bayeux. Then, we can choose from a variety of MCMC methods to perform inference.
To demonstrate the available backends, we will first simulate data and build a model.
num_samples = 100
num_features = 1
noise_std = 1.0
random_seed = 42

rng = np.random.default_rng(random_seed)

coefficients = rng.normal(size=num_features)
X = rng.normal(size=(num_samples, num_features))
error = rng.normal(scale=noise_std, size=num_samples)
y = X @ coefficients + error

data = pd.DataFrame({"y": y, "x": X.flatten()})

model = bmb.Model("y ~ x", data)
model.build()
We can call bmb.inference_methods.names, which returns a nested dictionary of the backends and the list of inference methods each one provides.
methods = bmb.inference_methods.names
methods
{'pymc': {'mcmc': ['mcmc'], 'vi': ['vi']},
'bayeux': {'mcmc': ['tfp_hmc',
'tfp_nuts',
'tfp_snaper_hmc',
'blackjax_hmc',
'blackjax_chees_hmc',
'blackjax_meads_hmc',
'blackjax_nuts',
'blackjax_hmc_pathfinder',
'blackjax_nuts_pathfinder',
'flowmc_rqspline_hmc',
'flowmc_rqspline_mala',
'flowmc_realnvp_hmc',
'flowmc_realnvp_mala',
'numpyro_hmc',
'numpyro_nuts',
'nutpie']}}
With the PyMC backend, we have access to its implementation of the NUTS sampler and to mean-field variational inference.
"pymc"] methods[
{'mcmc': ['mcmc'], 'vi': ['vi']}
bayeux gives us access to the TensorFlow Probability, BlackJAX, flowMC, and NumPyro backends.
"bayeux"] methods[
{'mcmc': ['tfp_hmc',
'tfp_nuts',
'tfp_snaper_hmc',
'blackjax_hmc',
'blackjax_chees_hmc',
'blackjax_meads_hmc',
'blackjax_nuts',
'blackjax_hmc_pathfinder',
'blackjax_nuts_pathfinder',
'flowmc_rqspline_hmc',
'flowmc_rqspline_mala',
'flowmc_realnvp_hmc',
'flowmc_realnvp_mala',
'numpyro_hmc',
'numpyro_nuts',
'nutpie']}
The values under the MCMC and VI keys of the dictionary are the names you would pass to the inference_method argument of model.fit. This is shown in the section below.
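For instance, to check programmatically that a method is available before requesting it, a minimal sketch using the dictionary above:

# Minimal sketch: confirm the method name appears in the nested dictionary
# before passing it to model.fit
assert "blackjax_nuts" in bmb.inference_methods.names["bayeux"]["mcmc"]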
Specifying an inference_method
By default, Bambi uses the PyMC NUTS implementation. To use a different backend, pass the name of the bayeux MCMC method to the inference_method parameter of the fit method.
Blackjax
= model.fit(inference_method="blackjax_nuts")
blackjax_nuts_idata blackjax_nuts_idata
WARNING:2024-12-21 13:43:24,702:jax._src.xla_bridge:969: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Output: InferenceData with groups posterior (8 chains × 500 draws; Intercept, sigma, x), sample_stats (acceptance_rate, diverging, energy, lp, n_steps, step_size, tree_depth), and observed_data (y, 100 observations).
Different backends have different naming conventions for the parameters specific to that MCMC method. Thus, to specify backend-specific parameters, pass your own kwargs to the fit method.
The following can be performed to identify the kwargs specific to each method.
"blackjax_nuts") bmb.inference_methods.get_kwargs(
{<function blackjax.adaptation.window_adaptation.window_adaptation(algorithm, logdensity_fn: Callable, is_mass_matrix_diagonal: bool = True, initial_step_size: float = 1.0, target_acceptance_rate: float = 0.8, progress_bar: bool = False, adaptation_info_fn: Callable = <function return_all_adapt_info at 0x7f164c18d120>, integrator=<function generate_euclidean_integrator.<locals>.euclidean_integrator at 0x7f164c15c680>, **extra_parameters) -> blackjax.base.AdaptationAlgorithm>: {'logdensity_fn': <function bayeux._src.shared.constrain.<locals>.wrap_log_density.<locals>.wrapped(args)>,
'is_mass_matrix_diagonal': True,
'initial_step_size': 1.0,
'target_acceptance_rate': 0.8,
'progress_bar': False,
'adaptation_info_fn': <function blackjax.adaptation.base.return_all_adapt_info(state, info, adaptation_state)>,
'algorithm': GenerateSamplingAPI(differentiable=<function as_top_level_api at 0x7f164c16a7a0>, init=<function init at 0x7f164c133380>, build_kernel=<function build_kernel at 0x7f164c169e40>)},
'adapt.run': {'num_steps': 500},
<function blackjax.mcmc.nuts.as_top_level_api(logdensity_fn: Callable, step_size: float, inverse_mass_matrix: Union[blackjax.mcmc.metrics.Metric, jax.Array, Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex, Iterable[ForwardRef('ArrayLikeTree')], Mapping[Any, ForwardRef('ArrayLikeTree')]]], jax.Array]], *, max_num_doublings: int = 10, divergence_threshold: int = 1000, integrator: Callable = <function generate_euclidean_integrator.<locals>.euclidean_integrator at 0x7f164c15c680>) -> blackjax.base.SamplingAlgorithm>: {'max_num_doublings': 10,
'divergence_threshold': 1000,
'integrator': <function blackjax.mcmc.integrators.generate_euclidean_integrator.<locals>.euclidean_integrator(logdensity_fn: Callable, kinetic_energy_fn: blackjax.mcmc.metrics.KineticEnergy) -> Callable[[blackjax.mcmc.integrators.IntegratorState, float], blackjax.mcmc.integrators.IntegratorState]>,
'logdensity_fn': <function bayeux._src.shared.constrain.<locals>.wrap_log_density.<locals>.wrapped(args)>,
'step_size': 0.5},
'extra_parameters': {'chain_method': 'vectorized',
'num_chains': 8,
'num_draws': 500,
'num_adapt_draws': 500,
'return_pytree': False}}
Now, we can identify the kwargs we would like to change and pass them to the fit method.
kwargs = {
    "adapt.run": {"num_steps": 500},
    "num_chains": 4,
    "num_draws": 250,
    "num_adapt_draws": 250,
}

blackjax_nuts_idata = model.fit(inference_method="blackjax_nuts", **kwargs)
blackjax_nuts_idata
Output: InferenceData with groups posterior (4 chains × 250 draws; Intercept, sigma, x), sample_stats, and observed_data.
TensorFlow Probability
= model.fit(inference_method="tfp_nuts")
tfp_nuts_idata tfp_nuts_idata
Output: InferenceData with groups posterior (8 chains × 1000 draws; Intercept, sigma, x), sample_stats, and observed_data.
NumPyro
= model.fit(inference_method="numpyro_nuts")
numpyro_nuts_idata numpyro_nuts_idata
sample: 100%|██████████| 1500/1500 [00:03<00:00, 386.97it/s]
Output: InferenceData with groups posterior (8 chains × 1000 draws; Intercept, sigma, x), sample_stats, and observed_data (inference_library: numpyro 0.15.3).
flowMC
= model.fit(inference_method="flowmc_realnvp_hmc")
flowmc_idata flowmc_idata
['n_dim', 'n_chains', 'n_local_steps', 'n_global_steps', 'n_loop', 'output_thinning', 'verbose']
Global Tuning: 100%|██████████| 5/5 [00:20<00:00, 4.05s/it]
Global Sampling: 100%|██████████| 5/5 [00:00<00:00, 26.22it/s]
Output: InferenceData with groups posterior (20 chains × 500 draws; Intercept, sigma, x) and observed_data.
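The list printed above shows flowMC's tunable parameter names. As with the BlackJAX example earlier, these can in principle be overridden by passing kwargs to fit; the values below are illustrative assumptions, not tuned recommendations.

# Illustrative sketch: override flowMC parameters via kwargs, mirroring the
# BlackJAX kwargs example; these particular values are assumptions
flowmc_idata = model.fit(
    inference_method="flowmc_realnvp_hmc", n_chains=10, n_loop=3
)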
nutpie
"nutpie") bmb.inference_methods.get_kwargs(
{<function nutpie.compiled_pyfunc.from_pyfunc(ndim: int, make_logp_fn: Callable, make_expand_fn: Callable, expanded_dtypes: list[numpy.dtype], expanded_shapes: list[tuple[int, ...]], expanded_names: list[str], *, initial_mean: numpy.ndarray | None = None, coords: dict[str, typing.Any] | None = None, dims: dict[str, tuple[str, ...]] | None = None, shared_data: dict[str, typing.Any] | None = None)>: {'ndim': 1,
'make_logp_fn': <function bayeux._src.mcmc.nutpie._NutpieSampler._get_aux.<locals>.make_logp_fn()>,
'make_expand_fn': <function bayeux._src.mcmc.nutpie._NutpieSampler.get_kwargs.<locals>.make_expand_fn(*args, **kwargs)>,
'expanded_shapes': [(1,)],
'expanded_names': ['x'],
'expanded_dtypes': [numpy.float64]},
<function nutpie.sample.sample(compiled_model: nutpie.sample.CompiledModel, *, draws: int = 1000, tune: int = 300, chains: int = 6, cores: Optional[int] = None, seed: Optional[int] = None, save_warmup: bool = True, progress_bar: bool = True, low_rank_modified_mass_matrix: bool = False, init_mean: Optional[numpy.ndarray] = None, return_raw_trace: bool = False, blocking: bool = True, progress_template: Optional[str] = None, progress_style: Optional[str] = None, progress_rate: int = 100, **kwargs) -> arviz.data.inference_data.InferenceData>: {'draws': 1000,
'tune': 300,
'chains': 8,
'cores': 8,
'seed': None,
'save_warmup': True,
'progress_bar': True,
'low_rank_modified_mass_matrix': False,
'init_mean': None,
'return_raw_trace': False,
'blocking': True,
'progress_template': None,
'progress_style': None,
'progress_rate': 100},
'extra_parameters': {'flatten': <function bayeux._src.mcmc.nutpie._NutpieSampler._get_aux.<locals>.flatten(pytree)>,
'unflatten': <jax._src.util.HashablePartial at 0x7f1545283cd0>,
'return_pytree': False}}
= model.fit(inference_method="nutpie", tune=400, draws=500, chains=3)
nutpie_idata nutpie_idata
Sampler progress: 3 of 3 chains finished (0 active), sampling complete.

| Draws | Divergences | Step Size | Gradients/Draw |
|---|---|---|---|
| 900 | 0 | 1.04 | 3 |
| 900 | 0 | 1.02 | 3 |
| 900 | 0 | 0.99 | 3 |
Output: InferenceData with groups posterior (3 chains × 500 draws; Intercept, sigma, x), sample_stats, observed_data, and, because save_warmup=True by default, warmup_posterior and warmup_sample_stats (3 chains × 400 warmup draws).
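Note the warmup groups in the result above: nutpie saves warmup draws by default (save_warmup=True in its kwargs). A sketch of refitting without them, assuming the kwarg is forwarded like the others:

# Sketch: skip storing warmup draws; assumes save_warmup is forwarded to
# nutpie the same way tune, draws, and chains are above
nutpie_idata = model.fit(
    inference_method="nutpie", tune=400, draws=500, chains=3, save_warmup=False
)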
Sampler comparisons
With ArviZ, we can compare the inference result summaries of the samplers. Note: we can't use az.compare because not every inference data object returns the pointwise log-likelihood values, so an error would be raised.
az.summary(blackjax_nuts_idata)
|           | mean   | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-----------|--------|-------|--------|---------|-----------|---------|----------|----------|-------|
| Intercept | -0.000 | 0.097 | -0.180 | 0.183   | 0.003     | 0.003   | 938.0    | 752.0    | 1.0   |
| sigma     | 0.987  | 0.073 | 0.859  | 1.126   | 0.002     | 0.002   | 913.0    | 739.0    | 1.0   |
| x         | 0.423  | 0.125 | 0.151  | 0.629   | 0.004     | 0.003   | 1044.0   | 820.0    | 1.0   |
az.summary(tfp_nuts_idata)
|           | mean  | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-----------|-------|-------|--------|---------|-----------|---------|----------|----------|-------|
| Intercept | 0.002 | 0.099 | -0.183 | 0.190   | 0.001     | 0.001   | 6775.0   | 5598.0   | 1.0   |
| sigma     | 0.987 | 0.071 | 0.848  | 1.114   | 0.001     | 0.001   | 8338.0   | 5715.0   | 1.0   |
| x         | 0.424 | 0.127 | 0.186  | 0.661   | 0.002     | 0.001   | 6244.0   | 5267.0   | 1.0   |
az.summary(numpyro_nuts_idata)
|           | mean  | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-----------|-------|-------|--------|---------|-----------|---------|----------|----------|-------|
| Intercept | 0.005 | 0.098 | -0.180 | 0.188   | 0.001     | 0.001   | 9065.0   | 6523.0   | 1.0   |
| sigma     | 0.988 | 0.074 | 0.856  | 1.127   | 0.001     | 0.001   | 7217.0   | 5477.0   | 1.0   |
| x         | 0.423 | 0.130 | 0.179  | 0.661   | 0.002     | 0.001   | 7449.0   | 6203.0   | 1.0   |
az.summary(flowmc_idata)
|           | mean  | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-----------|-------|-------|--------|---------|-----------|---------|----------|----------|-------|
| Intercept | 0.004 | 0.101 | -0.184 | 0.193   | 0.002     | 0.001   | 2352.0   | 3365.0   | 1.01  |
| sigma     | 0.987 | 0.070 | 0.861  | 1.123   | 0.001     | 0.001   | 4252.0   | 4034.0   | 1.01  |
| x         | 0.425 | 0.129 | 0.171  | 0.656   | 0.001     | 0.001   | 7504.0   | 3764.0   | 1.01  |
az.summary(nutpie_idata)
|           | mean  | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-----------|-------|-------|--------|---------|-----------|---------|----------|----------|-------|
| Intercept | 0.002 | 0.098 | -0.179 | 0.181   | 0.002     | 0.003   | 2288.0   | 1040.0   | 1.0   |
| sigma     | 0.989 | 0.072 | 0.857  | 1.118   | 0.002     | 0.001   | 2199.0   | 1155.0   | 1.0   |
| x         | 0.423 | 0.128 | 0.176  | 0.657   | 0.003     | 0.002   | 1956.0   | 1287.0   | 1.0   |
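To compare these side by side, a minimal sketch that stacks the summaries into a single DataFrame (using the idata objects created above):

# Minimal sketch: collect the per-sampler summaries into one DataFrame
idatas = {
    "blackjax_nuts": blackjax_nuts_idata,
    "tfp_nuts": tfp_nuts_idata,
    "numpyro_nuts": numpyro_nuts_idata,
    "flowmc_realnvp_hmc": flowmc_idata,
    "nutpie": nutpie_idata,
}
pd.concat(
    {name: az.summary(idata) for name, idata in idatas.items()},
    names=["sampler", "parameter"],
)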
Summary
Thanks to bayeux, we can use several different sampling backends and 10+ alternative MCMC methods in Bambi. Using these methods is as simple as passing the method name to the inference_method argument of the fit method.
%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Sat Dec 21 2024
Python implementation: CPython
Python version : 3.11.9
IPython version : 8.27.0
bambi : 0.14.1.dev17+g25798ce7
arviz : 0.19.0
pandas: 2.2.3
numpy : 1.26.4
Watermark: 2.5.0