import bambi as bmb
import numpy as np
import pandas as pd
Alternative sampling backends
In Bambi, the sampler is selected automatically based on the types of variables in the model. Bambi supports both MCMC and variational inference for fitting models. By default, Bambi samples with PyMC's implementation of the adaptive Hamiltonian Monte Carlo (HMC) algorithm known as the No-U-Turn Sampler (NUTS). This sampler is a good choice for many models. However, PyMC is not the only library that implements NUTS.
To this end, Bambi also supports the NumPyro, Blackjax, and nutpie NUTS samplers. This notebook covers how to use these alternative samplers in Bambi.
Note: To use these samplers, you need to install numpyro, blackjax, and/or nutpie with a package manager of your choice.
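Before choosing a backend, it can help to check which of these optional packages are importable in the current environment. The sketch below is not part of Bambi: the helper name installed_backends and the mapping from inference_method names to module names are assumptions for illustration.

```python
from importlib.util import find_spec

# Hypothetical mapping: inference_method name (as used in this notebook)
# -> top-level module the backend package installs.
BACKEND_MODULES = {
    "blackjax": "blackjax",
    "numpyro": "numpyro",
    "nutpie": "nutpie",
}


def installed_backends():
    """Return the alternative NUTS backends whose module can be found."""
    return [
        method
        for method, module in BACKEND_MODULES.items()
        # find_spec returns None when a top-level module is not installed
        if find_spec(module) is not None
    ]


print(installed_backends())
```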
Specifying an inference_method
To demonstrate the different inference methods, we will first simulate data and build a model.
# Simulation settings
num_samples = 100
num_features = 1
noise_std = 1.0
random_seed = 42
rng = np.random.default_rng(random_seed)
# Draw the true coefficients, the predictors, and Gaussian noise, then build the response
coefficients = rng.normal(size=num_features)
X = rng.normal(size=(num_samples, num_features))
error = rng.normal(scale=noise_std, size=num_samples)
y = X @ coefficients + error
data = pd.DataFrame({"y": y, "x": X.flatten()})
model = bmb.Model("y ~ x", data)
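Because the data are simulated, the ground truth is known: the slope is coefficients[0], the simulation has no intercept term (so the true Intercept is 0), and the noise standard deviation is noise_std = 1.0. As a rough reference point for the posteriors below, the following sketch regenerates the same data with NumPy (same seed, same order of draws) and fits ordinary least squares. This is an illustrative check, not something Bambi does; the names intercept_hat, slope_hat, and residual_sd are our own.

```python
import numpy as np

# Re-create the simulated data exactly as above so the truth is known
num_samples, num_features, noise_std, random_seed = 100, 1, 1.0, 42
rng = np.random.default_rng(random_seed)
coefficients = rng.normal(size=num_features)
X = rng.normal(size=(num_samples, num_features))
error = rng.normal(scale=noise_std, size=num_samples)
y = X @ coefficients + error

# Least-squares fit with an explicit intercept column
design = np.column_stack([np.ones(num_samples), X])
(intercept_hat, slope_hat), *_ = np.linalg.lstsq(design, y, rcond=None)

# Residual standard deviation, using 2 degrees of freedom for the 2 estimated coefficients
residuals = y - design @ np.array([intercept_hat, slope_hat])
residual_sd = np.std(residuals, ddof=2)

print(f"true slope: {coefficients[0]:.3f}")
print(f"OLS intercept: {intercept_hat:.3f}, slope: {slope_hat:.3f}, sd: {residual_sd:.3f}")
```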
By default, Bambi uses the PyMC NUTS implementation. To use a different backend, pass the name of the desired MCMC implementation to the inference_method parameter of the fit method.
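Since only the inference_method argument changes between runs, the same model can be fitted with several backends in a loop. The helper fit_with_backends below is a hypothetical convenience function, not part of Bambi's API; it only assumes that model.fit accepts inference_method plus extra keyword arguments, as in the calls in this notebook.

```python
def fit_with_backends(model, methods, **fit_kwargs):
    """Fit the same model once per backend and collect the results.

    Assumes `model.fit(inference_method=..., **kwargs)` works as for the
    Bambi model in this notebook. Extra keyword arguments (for example
    progressbar=False) are forwarded unchanged to every call.
    """
    results = {}
    for method in methods:
        # Each call runs a fresh fit with the requested NUTS implementation
        results[method] = model.fit(inference_method=method, **fit_kwargs)
    return results


# Usage with the Bambi model defined above (requires the backends to be installed):
# idatas = fit_with_backends(model, ["blackjax", "numpyro", "nutpie"], progressbar=False)
```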
Blackjax
%%time
blackjax_nuts_idata = model.fit(inference_method="blackjax", progressbar=False)
WARNING:2025-09-28 13:00:20,305:jax._src.xla_bridge:864: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
CPU times: user 7.64 s, sys: 3.05 s, total: 10.7 s
Wall time: 7.19 s
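The %%time magic only works in IPython. Outside a notebook, a comparable CPU and wall-clock measurement can be taken with the standard library, as in this minimal sketch; timed is a hypothetical helper name, not a Bambi function.

```python
import time


def timed(fn, *args, **kwargs):
    """Call fn(*args, **kwargs) and return (result, cpu_seconds, wall_seconds).

    CPU time comes from time.process_time(), i.e. user + system time of the
    current process summed over all its threads; child processes are not
    counted. Wall time comes from the monotonic time.perf_counter().
    """
    cpu_start, wall_start = time.process_time(), time.perf_counter()
    result = fn(*args, **kwargs)
    cpu = time.process_time() - cpu_start
    wall = time.perf_counter() - wall_start
    return result, cpu, wall


# Usage with the model above:
# idata, cpu, wall = timed(model.fit, inference_method="blackjax", progressbar=False)
```

The Blackjax run in this notebook records sampling_time: 2.107706 in its attributes against a wall time of 7.19 s, so a wall-clock measurement like this presumably also captures work outside sampling itself, such as setup and compilation.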
blackjax_nuts_idata
posterior
<xarray.Dataset> Size: 104kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999 Data variables: Intercept (chain, draw) float64 32kB 0.004918 0.02806 ... 0.1205 -0.01849 x (chain, draw) float64 32kB 0.3537 0.1788 0.473 ... 0.3752 0.4473 sigma (chain, draw) float64 32kB 1.058 1.083 0.9034 ... 1.033 0.9862 Attributes: created_at: 2025-09-28T16:00:24.433389+00:00 arviz_version: 0.22.0 inference_library: blackjax inference_library_version: 1.2.4 sampling_time: 2.107706 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.14.1.dev56+gd93591cd2.d20250927
sample_stats
<xarray.Dataset> Size: 172kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999 Data variables: acceptance_rate (chain, draw) float64 32kB 1.0 0.8539 ... 0.9393 0.9043 diverging (chain, draw) bool 4kB False False False ... False False energy (chain, draw) float64 32kB 146.1 147.1 ... 145.1 146.3 lp (chain, draw) float64 32kB -144.5 -146.4 ... -144.8 -143.8 n_steps (chain, draw) int64 32kB 3 3 3 3 3 7 7 3 ... 7 5 7 7 7 7 3 tree_depth (chain, draw) int64 32kB 2 2 2 2 2 3 3 2 ... 3 3 3 3 3 3 2 Attributes: created_at: 2025-09-28T16:00:24.442990+00:00 arviz_version: 0.22.0 modeling_interface: bambi modeling_interface_version: 0.14.1.dev56+gd93591cd2.d20250927
observed_data
<xarray.Dataset> Size: 2kB Dimensions: (__obs__: 100) Coordinates: * __obs__ (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99 Data variables: y (__obs__) float64 800B 0.9823 -0.1276 1.024 ... -0.4394 0.2223 Attributes: created_at: 2025-09-28T16:00:24.445274+00:00 arviz_version: 0.22.0 inference_library: blackjax inference_library_version: 1.2.4 sampling_time: 2.107706 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.14.1.dev56+gd93591cd2.d20250927
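The sample_stats group holds the sampler diagnostics. Blackjax and NumPyro both store a per-draw diverging flag and acceptance_rate, while nutpie (covered below) stores diverging and mean_tree_accept. The sketch below, a hypothetical helper named nuts_diagnostics, summarises these fields. It assumes only that the group supports membership tests and item access, which holds for idata.sample_stats (an xarray Dataset) and for a plain dict of NumPy arrays shaped (chain, draw).

```python
import numpy as np

# Acceptance-statistic field names as they appear in this notebook's outputs
ACCEPT_FIELDS = ("acceptance_rate", "mean_tree_accept")


def nuts_diagnostics(sample_stats):
    """Count divergences and average the acceptance statistic over all draws."""
    diverging = np.asarray(sample_stats["diverging"])
    summary = {
        "divergences": int(diverging.sum()),
        "draws": int(diverging.size),
    }
    # Use whichever acceptance field this backend recorded
    for field in ACCEPT_FIELDS:
        if field in sample_stats:
            summary["mean_accept"] = float(np.mean(np.asarray(sample_stats[field])))
            break
    return summary


# Usage: nuts_diagnostics(blackjax_nuts_idata.sample_stats)
```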
NumPyro
%%time
numpyro_nuts_idata = model.fit(inference_method="numpyro", progressbar=False)
CPU times: user 3.43 s, sys: 294 ms, total: 3.72 s
Wall time: 1.55 s
numpyro_nuts_idata
posterior
<xarray.Dataset> Size: 104kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999 Data variables: Intercept (chain, draw) float64 32kB 0.184 -0.1767 ... -0.1061 0.07233 x (chain, draw) float64 32kB 0.3783 0.4742 0.5178 ... 0.5458 0.3218 sigma (chain, draw) float64 32kB 0.9515 1.011 1.006 ... 1.018 0.939 Attributes: created_at: 2025-09-28T16:00:26.212193+00:00 arviz_version: 0.22.0 inference_library: numpyro inference_library_version: 0.19.0 sampling_time: 1.23596 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.14.1.dev56+gd93591cd2.d20250927
sample_stats
<xarray.Dataset> Size: 204kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999 Data variables: acceptance_rate (chain, draw) float64 32kB 0.8278 1.0 0.9665 ... 1.0 0.814 step_size (chain, draw) float64 32kB 0.7569 0.7569 ... 0.9014 0.9014 diverging (chain, draw) bool 4kB False False False ... False False energy (chain, draw) float64 32kB 146.4 145.8 ... 145.7 147.1 n_steps (chain, draw) int64 32kB 3 7 3 7 7 3 7 7 ... 3 3 3 3 7 1 11 tree_depth (chain, draw) int64 32kB 2 3 2 3 3 2 3 3 ... 2 2 2 2 3 1 4 lp (chain, draw) float64 32kB 145.7 145.6 ... 145.0 144.5 Attributes: created_at: 2025-09-28T16:00:26.216944+00:00 arviz_version: 0.22.0 modeling_interface: bambi modeling_interface_version: 0.14.1.dev56+gd93591cd2.d20250927
observed_data
<xarray.Dataset> Size: 2kB Dimensions: (__obs__: 100) Coordinates: * __obs__ (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99 Data variables: y (__obs__) float64 800B 0.9823 -0.1276 1.024 ... -0.4394 0.2223 Attributes: created_at: 2025-09-28T16:00:26.218137+00:00 arviz_version: 0.22.0 inference_library: numpyro inference_library_version: 0.19.0 sampling_time: 1.23596 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.14.1.dev56+gd93591cd2.d20250927
nutpie
nutpie_idata = model.fit(inference_method="nutpie", progressbar=False)
nutpie_idata
posterior
<xarray.Dataset> Size: 136kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999 Data variables: sigma_log__ (chain, draw) float64 32kB -0.09139 0.02693 ... 0.1097 -0.1576 Intercept (chain, draw) float64 32kB 0.03335 -0.01864 ... 0.07132 x (chain, draw) float64 32kB 0.3313 0.53 0.3076 ... 0.5508 0.3157 sigma (chain, draw) float64 32kB 0.9127 1.027 0.8574 ... 1.116 0.8542 Attributes: created_at: 2025-09-28T16:00:34.798058+00:00 arviz_version: 0.22.0 inference_library: nutpie inference_library_version: 0.15.2 sampling_time: 0.0895071029663086 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.14.1.dev56+gd93591cd2.d20250927
sample_stats
<xarray.Dataset> Size: 336kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999 Data variables: depth (chain, draw) uint64 32kB 2 2 2 2 2 2 ... 2 2 2 2 2 2 maxdepth_reached (chain, draw) bool 4kB False False ... False False index_in_trajectory (chain, draw) int64 32kB 1 -3 2 -2 2 ... 3 -3 -2 2 -3 logp (chain, draw) float64 32kB -144.5 -144.4 ... -146.3 energy (chain, draw) float64 32kB 145.0 144.5 ... 146.1 148.9 diverging (chain, draw) bool 4kB False False ... False False energy_error (chain, draw) float64 32kB 0.08409 ... -0.05117 step_size (chain, draw) float64 32kB 1.097 1.097 ... 1.103 1.103 step_size_bar (chain, draw) float64 32kB 1.097 1.097 ... 1.103 1.103 mean_tree_accept (chain, draw) float64 32kB 0.8706 1.0 ... 0.8117 0.84 mean_tree_accept_sym (chain, draw) float64 32kB 0.9303 0.9425 ... 0.8828 n_steps (chain, draw) uint64 32kB 3 3 3 3 3 3 ... 3 3 3 3 3 3 Attributes: created_at: 2025-09-28T16:00:34.788419+00:00 arviz_version: 0.22.0 modeling_interface: bambi modeling_interface_version: 0.14.1.dev56+gd93591cd2.d20250927
observed_data
<xarray.Dataset> Size: 2kB Dimensions: (__obs__: 100) Coordinates: * __obs__ (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99 Data variables: y (__obs__) float64 800B 0.9823 -0.1276 1.024 ... -0.4394 0.2223 Attributes: created_at: 2025-09-28T16:00:34.797508+00:00 arviz_version: 0.22.0 inference_library: pymc inference_library_version: 3.9.2+2907.g7a3db78e6 modeling_interface: bambi modeling_interface_version: 0.14.1.dev56+gd93591cd2.d20250927
warmup_posterior
<xarray.Dataset> Size: 136kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999 Data variables: sigma_log__ (chain, draw) float64 32kB -0.6403 -0.6403 ... -0.1866 -0.1695 Intercept (chain, draw) float64 32kB -0.7069 -0.7069 ... -0.06133 x (chain, draw) float64 32kB -0.6089 -0.6089 ... 0.5059 0.5003 sigma (chain, draw) float64 32kB 0.5271 0.5271 ... 0.8298 0.8441 Attributes: created_at: 2025-09-28T16:00:34.783875+00:00 arviz_version: 0.22.0 modeling_interface: bambi modeling_interface_version: 0.14.1.dev56+gd93591cd2.d20250927
warmup_sample_stats
<xarray.Dataset> Size: 336kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999 Data variables: depth (chain, draw) uint64 32kB 3 0 1 2 1 0 ... 2 2 2 2 1 1 maxdepth_reached (chain, draw) bool 4kB False False ... False False index_in_trajectory (chain, draw) int64 32kB 0 0 -1 -2 0 ... 2 -3 -2 0 -1 logp (chain, draw) float64 32kB -401.7 -401.7 ... -146.2 energy (chain, draw) float64 32kB 402.3 402.8 ... 152.4 147.0 diverging (chain, draw) bool 4kB False True ... False False energy_error (chain, draw) float64 32kB 0.0 0.0 ... 0.0 -0.1304 step_size (chain, draw) float64 32kB 3.736 0.3684 ... 1.103 step_size_bar (chain, draw) float64 32kB 3.736 0.9423 ... 1.103 mean_tree_accept (chain, draw) float64 32kB 1.58e-34 0.0 ... 1.0 mean_tree_accept_sym (chain, draw) float64 32kB 3.16e-34 0.0 ... 0.9349 n_steps (chain, draw) uint64 32kB 7 1 1 3 1 1 ... 3 3 3 3 1 1 Attributes: created_at: 2025-09-28T16:00:34.792568+00:00 arviz_version: 0.22.0 modeling_interface: bambi modeling_interface_version: 0.14.1.dev56+gd93591cd2.d20250927
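With all three backends fitted, a quick consistency check is to compare the posterior means of the variables they share. Exact agreement should not be expected because of Monte Carlo error, but large discrepancies would be worth investigating. The nutpie posterior also contains the log-transformed sigma_log__, which the other backends do not store, so this sketch keeps only shared variables. compare_posterior_means is a hypothetical helper; it accepts idata.posterior objects or dicts of NumPy arrays shaped (chain, draw).

```python
import numpy as np


def compare_posterior_means(posteriors):
    """Return {variable: {backend: posterior mean}} for variables every backend has.

    `posteriors` maps a backend name to a mapping of variable name -> array.
    At least one backend must be given.
    """
    # Keep only variables present in every backend's posterior
    shared = set.intersection(*(set(post.keys()) for post in posteriors.values()))
    return {
        name: {
            backend: float(np.mean(np.asarray(post[name])))
            for backend, post in posteriors.items()
        }
        for name in sorted(shared)
    }


# Usage:
# compare_posterior_means({
#     "blackjax": blackjax_nuts_idata.posterior,
#     "numpyro": numpyro_nuts_idata.posterior,
#     "nutpie": nutpie_idata.posterior,
# })
```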
%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Sun Sep 28 2025
Python implementation: CPython
Python version : 3.13.7
IPython version : 9.4.0
numpy : 2.3.3
pandas: 2.3.2
bambi : 0.14.1.dev56+gd93591cd2.d20250927
Watermark: 2.5.0