import bambi as bmb
import numpy as np
import pandas as pd
Alternative sampling backends
In Bambi, the sampler is selected automatically based on the types of variables in the model. Bambi supports both MCMC and variational inference for fitting models. By default, Bambi samples with PyMC’s implementation of the No-U-Turn Sampler (NUTS), an adaptive variant of Hamiltonian Monte Carlo (HMC). This sampler is a good choice for many models. However, PyMC is not the only library that implements NUTS.
To this end, Bambi also supports the NUTS samplers from NumPyro, Blackjax, and nutpie. This notebook shows how to use these alternative samplers in Bambi.
Note: To use these samplers, you need to install numpyro, blackjax, and/or nutpie with a package manager of your choice.
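Before calling fit with an alternative backend, it can help to check which of the optional packages are importable in the current environment. This is a minimal sketch, assuming each backend's import name matches the package names listed above (numpyro, blackjax, nutpie). The helper name available_backends is hypothetical and not part of Bambi.

```python
import importlib.util

# Optional NUTS backends discussed above. Assumption: the importable module
# name of each package equals the package name.
OPTIONAL_BACKENDS = ("numpyro", "blackjax", "nutpie")


def available_backends(candidates=OPTIONAL_BACKENDS):
    """Return the candidates whose top-level module can be found (not imported)."""
    return [name for name in candidates if importlib.util.find_spec(name) is not None]


print(available_backends())
```

importlib.util.find_spec only locates the module without importing it, so the check stays cheap even for JAX-based packages.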
Specifying an inference_method
To demonstrate the different inference methods, we will first simulate data and build a model.
# Simulation settings
num_samples = 100
num_features = 1
noise_std = 1.0
random_seed = 42

rng = np.random.default_rng(random_seed)

# Linear model with no intercept term: y = X @ coefficients + Gaussian noise
coefficients = rng.normal(size=num_features)
X = rng.normal(size=(num_samples, num_features))
error = rng.normal(scale=noise_std, size=num_samples)
y = X @ coefficients + error

# Build a data frame and a simple linear regression in Bambi
data = pd.DataFrame({"y": y, "x": X.flatten()})
model = bmb.Model("y ~ x", data)
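Because the data are simulated, the true slope is known (coefficients) and the true intercept is zero. As a rough reference point for the posterior summaries of Intercept and x below, the following sketch regenerates the same data using only NumPy and computes an ordinary least squares fit with np.linalg.lstsq. The OLS comparison is an addition for illustration, not part of the original workflow. With 100 observations and Bambi's default weakly informative priors, we would expect the posterior means to land close to these estimates.

```python
import numpy as np

# Same simulation settings and seed as above, so the draws are identical.
num_samples, num_features, noise_std, random_seed = 100, 1, 1.0, 42
rng = np.random.default_rng(random_seed)
coefficients = rng.normal(size=num_features)
X = rng.normal(size=(num_samples, num_features))
error = rng.normal(scale=noise_std, size=num_samples)
y = X @ coefficients + error

# Ordinary least squares with an explicit intercept column: [1, x].
design = np.column_stack([np.ones(num_samples), X])
beta_hat, *_ = np.linalg.lstsq(design, y, rcond=None)

print("true slope:", coefficients[0])
print("OLS intercept and slope:", beta_hat)
```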
By default, Bambi uses the PyMC NUTS implementation. To use a different backend, pass the name of the desired MCMC implementation to the inference_method parameter of the fit method.
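The cells below time each backend with the %%time magic. Outside a notebook, a similar comparison can be scripted. This is a minimal sketch that assumes only what the text above states: each backend is selected by passing its name to model.fit(inference_method=...). The helper name time_fits is hypothetical. Note that the measured time for a JAX-based backend typically includes compilation on the first call.

```python
import time


def time_fits(model, backends=("numpyro", "blackjax", "nutpie"), **fit_kwargs):
    """Fit `model` once per backend name and time each call.

    Returns {backend: (idata, elapsed_seconds)}. Extra keyword arguments,
    such as progressbar=False, are forwarded unchanged to model.fit().
    """
    results = {}
    for backend in backends:
        start = time.perf_counter()
        idata = model.fit(inference_method=backend, **fit_kwargs)
        results[backend] = (idata, time.perf_counter() - start)
    return results


# Usage with the model built above (requires the corresponding packages):
# fits = time_fits(model, progressbar=False)
# for backend, (_, seconds) in fits.items():
#     print(f"{backend}: {seconds:.2f} s")
```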
Blackjax
%%time
blackjax_nuts_idata = model.fit(inference_method="blackjax", progressbar=False)
CPU times: user 3.24 s, sys: 3.2 s, total: 6.43 s
Wall time: 1.58 s
blackjax_nuts_idata
<xarray.Dataset> Size: 104kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999 Data variables: Intercept (chain, draw) float64 32kB -0.1095 0.05797 ... 0.1161 -0.06526 x (chain, draw) float64 32kB 0.4387 0.3546 0.2288 ... 0.3748 0.472 sigma (chain, draw) float64 32kB 1.022 0.9304 0.9587 ... 1.117 0.8317 Attributes: created_at: 2025-09-11T05:42:43.343255+00:00 arviz_version: 0.21.0 inference_library: blackjax inference_library_version: 1.2.5 sampling_time: 0.95643 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.15.1.dev17+g2db951a93.d20250910
<xarray.Dataset> Size: 172kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999 Data variables: acceptance_rate (chain, draw) float64 32kB 0.9694 0.9965 ... 1.0 0.9465 diverging (chain, draw) bool 4kB False False False ... False False energy (chain, draw) float64 32kB 146.4 145.1 ... 148.2 147.5 lp (chain, draw) float64 32kB -144.6 -144.2 ... -146.1 -146.8 n_steps (chain, draw) int64 32kB 3 7 3 7 3 3 3 3 ... 3 7 3 7 7 7 3 tree_depth (chain, draw) int64 32kB 2 3 2 3 2 2 2 2 ... 2 3 2 3 3 3 2 Attributes: created_at: 2025-09-11T05:42:43.346499+00:00 arviz_version: 0.21.0 modeling_interface: bambi modeling_interface_version: 0.15.1.dev17+g2db951a93.d20250910
<xarray.Dataset> Size: 2kB Dimensions: (__obs__: 100) Coordinates: * __obs__ (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99 Data variables: y (__obs__) float64 800B 0.9823 -0.1276 1.024 ... -0.4394 0.2223 Attributes: created_at: 2025-09-11T05:42:43.347121+00:00 arviz_version: 0.21.0 inference_library: blackjax inference_library_version: 1.2.5 sampling_time: 0.95643 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.15.1.dev17+g2db951a93.d20250910
NumPyro
%%time
numpyro_nuts_idata = model.fit(inference_method="numpyro", progressbar=False)
CPU times: user 2.23 s, sys: 111 ms, total: 2.34 s
Wall time: 1.11 s
numpyro_nuts_idata
<xarray.Dataset> Size: 104kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999 Data variables: Intercept (chain, draw) float64 32kB -0.09142 0.06636 ... -0.08637 -0.08637 x (chain, draw) float64 32kB 0.4035 0.4687 0.3535 ... 0.2083 0.2083 sigma (chain, draw) float64 32kB 1.0 1.083 1.008 ... 0.9983 0.9983 Attributes: created_at: 2025-09-11T05:42:47.864004+00:00 arviz_version: 0.21.0 inference_library: numpyro inference_library_version: 0.19.0 sampling_time: 0.923537 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.15.1.dev17+g2db951a93.d20250910
<xarray.Dataset> Size: 204kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999 Data variables: acceptance_rate (chain, draw) float64 32kB 1.0 0.9147 1.0 ... 1.0 0.7743 step_size (chain, draw) float64 32kB 0.7985 0.7985 ... 0.8912 0.8912 diverging (chain, draw) bool 4kB False False False ... False False energy (chain, draw) float64 32kB 145.8 145.4 ... 147.9 146.7 n_steps (chain, draw) int64 32kB 3 3 3 3 3 3 3 3 ... 3 3 3 3 3 1 1 tree_depth (chain, draw) int64 32kB 2 2 2 2 2 2 2 2 ... 2 2 2 2 2 1 1 lp (chain, draw) float64 32kB 144.2 145.0 ... 145.5 145.5 Attributes: created_at: 2025-09-11T05:42:47.866678+00:00 arviz_version: 0.21.0 modeling_interface: bambi modeling_interface_version: 0.15.1.dev17+g2db951a93.d20250910
<xarray.Dataset> Size: 2kB Dimensions: (__obs__: 100) Coordinates: * __obs__ (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99 Data variables: y (__obs__) float64 800B 0.9823 -0.1276 1.024 ... -0.4394 0.2223 Attributes: created_at: 2025-09-11T05:42:47.867365+00:00 arviz_version: 0.21.0 inference_library: numpyro inference_library_version: 0.19.0 sampling_time: 0.923537 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.15.1.dev17+g2db951a93.d20250910
nutpie
nutpie_idata = model.fit(inference_method="nutpie", progressbar=False)
/Users/gabestechschulte/projects/bambi/.venv/lib/python3.12/site-packages/pymc/sampling/mcmc.py:335: UserWarning: `var_names` are currently ignored by the nutpie sampler
warnings.warn(
nutpie_idata
<xarray.Dataset> Size: 3MB Dimensions: (chain: 4, draw: 1000, __obs__: 100) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999 * __obs__ (__obs__) int64 800B 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99 Data variables: sigma_log__ (chain, draw) float64 32kB -0.01424 -0.0169 ... -0.06667 Intercept (chain, draw) float64 32kB -0.06445 -0.0443 ... 0.09378 x (chain, draw) float64 32kB 0.305 0.1792 0.1207 ... 0.6013 0.288 sigma (chain, draw) float64 32kB 0.9859 0.9832 ... 0.9713 0.9355 mu (chain, draw, __obs__) float64 3MB -0.3816 0.1644 ... -0.01512 Attributes: created_at: 2025-09-11T05:43:00.405915+00:00 arviz_version: 0.21.0 inference_library: nutpie inference_library_version: 0.15.2 sampling_time: 0.03235673904418945 tuning_steps: 1000 modeling_interface: bambi modeling_interface_version: 0.15.1.dev17+g2db951a93.d20250910
<xarray.Dataset> Size: 336kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999 Data variables: depth (chain, draw) uint64 32kB 2 2 1 2 2 2 ... 2 2 2 2 2 2 maxdepth_reached (chain, draw) bool 4kB False False ... False False index_in_trajectory (chain, draw) int64 32kB 2 1 -1 2 -3 ... 2 -2 -1 2 -2 logp (chain, draw) float64 32kB -144.3 -145.6 ... -145.0 energy (chain, draw) float64 32kB 145.0 145.7 ... 146.6 145.5 diverging (chain, draw) bool 4kB False False ... False False energy_error (chain, draw) float64 32kB -0.1328 0.3533 ... -0.02633 step_size (chain, draw) float64 32kB 1.052 1.052 ... 1.088 1.088 step_size_bar (chain, draw) float64 32kB 1.052 1.052 ... 1.088 1.088 mean_tree_accept (chain, draw) float64 32kB 1.0 0.9008 ... 1.0 1.0 mean_tree_accept_sym (chain, draw) float64 32kB 0.94 0.9265 ... 0.9042 n_steps (chain, draw) uint64 32kB 3 3 1 3 3 3 ... 3 3 3 3 3 3 Attributes: created_at: 2025-09-11T05:43:00.401014+00:00 arviz_version: 0.21.0 modeling_interface: bambi modeling_interface_version: 0.15.1.dev17+g2db951a93.d20250910
<xarray.Dataset> Size: 2kB Dimensions: (__obs__: 100) Coordinates: * __obs__ (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99 Data variables: y (__obs__) float64 800B 0.9823 -0.1276 1.024 ... -0.4394 0.2223 Attributes: created_at: 2025-09-11T05:43:00.405522+00:00 arviz_version: 0.21.0 inference_library: pymc inference_library_version: 5.23.0 modeling_interface: bambi modeling_interface_version: 0.15.1.dev17+g2db951a93.d20250910
<xarray.Dataset> Size: 3MB Dimensions: (chain: 4, draw: 1000, __obs__: 100) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999 * __obs__ (__obs__) int64 800B 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99 Data variables: sigma_log__ (chain, draw) float64 32kB -0.6316 -0.6316 ... 0.02711 Intercept (chain, draw) float64 32kB 0.4837 0.4837 ... 0.09586 -0.04978 x (chain, draw) float64 32kB -0.9797 -0.9797 ... 0.4696 0.3999 sigma (chain, draw) float64 32kB 0.5317 0.5317 7.545 ... 0.9999 1.027 mu (chain, draw, __obs__) float64 3MB 1.447 -0.3075 ... -0.1782 Attributes: created_at: 2025-09-11T05:43:00.398914+00:00 arviz_version: 0.21.0 modeling_interface: bambi modeling_interface_version: 0.15.1.dev17+g2db951a93.d20250910
<xarray.Dataset> Size: 336kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999 Data variables: depth (chain, draw) uint64 32kB 3 0 6 5 1 0 ... 1 2 1 2 2 2 maxdepth_reached (chain, draw) bool 4kB False False ... False False index_in_trajectory (chain, draw) int64 32kB 0 0 -14 -6 0 ... -2 1 -2 2 2 logp (chain, draw) float64 32kB -453.4 -453.4 ... -144.1 energy (chain, draw) float64 32kB 453.6 455.3 ... 144.8 145.3 diverging (chain, draw) bool 4kB False True ... False False energy_error (chain, draw) float64 32kB 0.0 0.0 ... -0.0436 -0.2045 step_size (chain, draw) float64 32kB 3.736 0.3684 ... 1.088 step_size_bar (chain, draw) float64 32kB 3.736 0.9423 ... 1.088 mean_tree_accept (chain, draw) float64 32kB 1.312e-28 0.0 ... 0.9112 mean_tree_accept_sym (chain, draw) float64 32kB 2.624e-28 0.0 ... 0.9075 n_steps (chain, draw) uint64 32kB 7 1 63 31 1 1 ... 3 3 3 3 3 Attributes: created_at: 2025-09-11T05:43:00.403552+00:00 arviz_version: 0.21.0 modeling_interface: bambi modeling_interface_version: 0.15.1.dev17+g2db951a93.d20250910
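The sampler statistics reported by each backend use different names: acceptance_rate versus mean_tree_accept, lp versus logp, and tree_depth versus depth. However, every sample_stats dataset shown above includes a boolean diverging variable with dimensions (chain, draw). The following sketch uses that shared variable to compare backends uniformly. It assumes only that each result exposes sample_stats["diverging"] with a .sum() method, which holds for InferenceData objects. The helper name count_divergences and the toy inputs are hypothetical.

```python
from types import SimpleNamespace

import numpy as np


def count_divergences(idatas):
    """Return {backend: number of divergent transitions} for a dict of results.

    Each value only needs a sample_stats mapping whose "diverging" entry is a
    boolean (chain, draw) array supporting .sum() (e.g. an xarray DataArray).
    """
    return {name: int(idata.sample_stats["diverging"].sum()) for name, idata in idatas.items()}


# Toy stand-ins for InferenceData objects, for illustration only.
toy = {
    "a": SimpleNamespace(sample_stats={"diverging": np.array([[False, True], [False, False]])}),
    "b": SimpleNamespace(sample_stats={"diverging": np.zeros((4, 1000), dtype=bool)}),
}
print(count_divergences(toy))

# With the fits above:
# count_divergences({"blackjax": blackjax_nuts_idata,
#                    "numpyro": numpyro_nuts_idata,
#                    "nutpie": nutpie_idata})
```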
%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Thu Sep 11 2025
Python implementation: CPython
Python version : 3.12.5
IPython version : 9.5.0
pandas: 2.3.0
numpy : 2.2.6
arviz : 0.21.0
bambi : 0.15.1.dev17+g2db951a93.d20250910
Watermark: 2.5.0