Alternative sampling backends

In Bambi, the sampler is selected automatically based on the types of variables in the model. Bambi supports both MCMC and variational inference for fitting models. By default, Bambi samples with PyMC's implementation of the adaptive Hamiltonian Monte Carlo (HMC) algorithm known as the No-U-Turn Sampler (NUTS). This sampler is a good choice for many models. However, PyMC is not the only library that implements NUTS.

To this end, Bambi also supports the NUTS samplers provided by NumPyro, Blackjax, and nutpie. This notebook shows how to use these alternative samplers in Bambi.

Note: To use one of these samplers, you need to install the corresponding package (numpyro, blackjax, and/or nutpie) with a package manager of your choice.
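Before choosing a backend, it can help to check which of the optional packages are importable in the current environment. The snippet below is a minimal sketch using the standard library's importlib.util.find_spec; the helper name available_backends is hypothetical, and it only checks that each package can be found, not that its version is compatible with your Bambi release.

```python
import importlib.util

# Package names of the optional NUTS samplers discussed in this notebook.
OPTIONAL_BACKENDS = ("numpyro", "blackjax", "nutpie")


def available_backends(packages=OPTIONAL_BACKENDS):
    """Hypothetical helper: report which optional sampler packages are importable."""
    # find_spec returns None when a top-level package cannot be found;
    # it locates the package without importing it.
    return {pkg: importlib.util.find_spec(pkg) is not None for pkg in packages}


print(available_backends())
```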

import bambi as bmb
import numpy as np
import pandas as pd

Specifying an inference_method

To demonstrate the different inference methods, we will first simulate data and build a model.

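num_samples = 100  # number of observations
num_features = 1   # a single predictor, x
noise_std = 1.0    # standard deviation of the Gaussian noise
random_seed = 42

rng = np.random.default_rng(random_seed)

# Draw the true coefficient, the predictor values, and the noise.
coefficients = rng.normal(size=num_features)
X = rng.normal(size=(num_samples, num_features))
error = rng.normal(scale=noise_std, size=num_samples)
y = X @ coefficients + error  # the true intercept is 0

data = pd.DataFrame({"y": y, "x": X.flatten()})
model = bmb.Model("y ~ x", data)  # the formula includes an intercept by default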
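Because the data are simulated, the true slope is known, and an ordinary least-squares fit gives a rough point of reference for the posteriors below. With 100 observations and Bambi's default weakly informative priors, the posterior means of Intercept, x, and sigma would be expected to land near these estimates; that is an expectation, not a guarantee. This is a minimal NumPy-only sketch that regenerates X and y with the same seed and the same sequence of draws.

```python
import numpy as np

# Regenerate the simulated data exactly as above (same seed, same call order).
rng = np.random.default_rng(42)
coefficients = rng.normal(size=1)
X = rng.normal(size=(100, 1))
error = rng.normal(scale=1.0, size=100)
y = X @ coefficients + error

# Design matrix with a column of ones for the intercept, mirroring "y ~ x".
design = np.column_stack([np.ones(len(y)), X[:, 0]])
beta_hat, residuals, rank, _ = np.linalg.lstsq(design, y, rcond=None)

# Residual standard deviation with n - p degrees of freedom, comparable to sigma.
n, p = design.shape
sigma_hat = np.sqrt(np.sum((y - design @ beta_hat) ** 2) / (n - p))

print("true slope:   ", coefficients[0])
print("OLS intercept:", beta_hat[0])
print("OLS slope:    ", beta_hat[1])
print("residual sd:  ", sigma_hat)
```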

By default, Bambi uses the PyMC NUTS implementation. To use a different backend, pass its name ("blackjax", "numpyro", or "nutpie") to the inference_method parameter of the fit method.

Blackjax

%%time
blackjax_nuts_idata = model.fit(inference_method="blackjax", progressbar=False)
CPU times: user 3.24 s, sys: 3.2 s, total: 6.43 s
Wall time: 1.58 s
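%%time is an IPython cell magic, so it is not available in a plain Python script. A minimal stand-in, assuming you only need the two numbers it reports: time.process_time gives the CPU time (user plus system, summed over all threads of the process) and time.perf_counter gives the wall-clock time. The helper name timed is hypothetical. CPU time exceeding wall time, as in the Blackjax cell above, simply means several threads were busy at once.

```python
import time


def timed(fn, *args, **kwargs):
    """Hypothetical helper: call fn and report CPU and wall-clock time, like %%time."""
    cpu_start = time.process_time()   # user + system CPU time of this process
    wall_start = time.perf_counter()  # monotonic high-resolution wall clock
    result = fn(*args, **kwargs)
    cpu = time.process_time() - cpu_start
    wall = time.perf_counter() - wall_start
    print(f"CPU time: {cpu:.2f} s, Wall time: {wall:.2f} s")
    return result, cpu, wall


# Stand-in workload; in this notebook you would pass the fit method instead, e.g.
# timed(model.fit, inference_method="blackjax", progressbar=False)
result, cpu, wall = timed(sum, range(1_000_000))
```

For the JAX-based backends, a single timing of this kind likely includes compilation overhead, so a repeated call may be faster.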
blackjax_nuts_idata
arviz.InferenceData
    • <xarray.Dataset> Size: 104kB
      Dimensions:    (chain: 4, draw: 1000)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
      Data variables:
          Intercept  (chain, draw) float64 32kB -0.1095 0.05797 ... 0.1161 -0.06526
          x          (chain, draw) float64 32kB 0.4387 0.3546 0.2288 ... 0.3748 0.472
          sigma      (chain, draw) float64 32kB 1.022 0.9304 0.9587 ... 1.117 0.8317
      Attributes:
          created_at:                  2025-09-11T05:42:43.343255+00:00
          arviz_version:               0.21.0
          inference_library:           blackjax
          inference_library_version:   1.2.5
          sampling_time:               0.95643
          tuning_steps:                1000
          modeling_interface:          bambi
          modeling_interface_version:  0.15.1.dev17+g2db951a93.d20250910

    • <xarray.Dataset> Size: 172kB
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          acceptance_rate  (chain, draw) float64 32kB 0.9694 0.9965 ... 1.0 0.9465
          diverging        (chain, draw) bool 4kB False False False ... False False
          energy           (chain, draw) float64 32kB 146.4 145.1 ... 148.2 147.5
          lp               (chain, draw) float64 32kB -144.6 -144.2 ... -146.1 -146.8
          n_steps          (chain, draw) int64 32kB 3 7 3 7 3 3 3 3 ... 3 7 3 7 7 7 3
          tree_depth       (chain, draw) int64 32kB 2 3 2 3 2 2 2 2 ... 2 3 2 3 3 3 2
      Attributes:
          created_at:                  2025-09-11T05:42:43.346499+00:00
          arviz_version:               0.21.0
          modeling_interface:          bambi
          modeling_interface_version:  0.15.1.dev17+g2db951a93.d20250910

    • <xarray.Dataset> Size: 2kB
      Dimensions:  (__obs__: 100)
      Coordinates:
        * __obs__  (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
      Data variables:
          y        (__obs__) float64 800B 0.9823 -0.1276 1.024 ... -0.4394 0.2223
      Attributes:
          created_at:                  2025-09-11T05:42:43.347121+00:00
          arviz_version:               0.21.0
          inference_library:           blackjax
          inference_library_version:   1.2.5
          sampling_time:               0.95643
          tuning_steps:                1000
          modeling_interface:          bambi
          modeling_interface_version:  0.15.1.dev17+g2db951a93.d20250910

NumPyro

%%time
numpyro_nuts_idata = model.fit(inference_method="numpyro", progressbar=False)
CPU times: user 2.23 s, sys: 111 ms, total: 2.34 s
Wall time: 1.11 s
numpyro_nuts_idata
arviz.InferenceData
    • <xarray.Dataset> Size: 104kB
      Dimensions:    (chain: 4, draw: 1000)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
      Data variables:
          Intercept  (chain, draw) float64 32kB -0.09142 0.06636 ... -0.08637 -0.08637
          x          (chain, draw) float64 32kB 0.4035 0.4687 0.3535 ... 0.2083 0.2083
          sigma      (chain, draw) float64 32kB 1.0 1.083 1.008 ... 0.9983 0.9983
      Attributes:
          created_at:                  2025-09-11T05:42:47.864004+00:00
          arviz_version:               0.21.0
          inference_library:           numpyro
          inference_library_version:   0.19.0
          sampling_time:               0.923537
          tuning_steps:                1000
          modeling_interface:          bambi
          modeling_interface_version:  0.15.1.dev17+g2db951a93.d20250910

    • <xarray.Dataset> Size: 204kB
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          acceptance_rate  (chain, draw) float64 32kB 1.0 0.9147 1.0 ... 1.0 0.7743
          step_size        (chain, draw) float64 32kB 0.7985 0.7985 ... 0.8912 0.8912
          diverging        (chain, draw) bool 4kB False False False ... False False
          energy           (chain, draw) float64 32kB 145.8 145.4 ... 147.9 146.7
          n_steps          (chain, draw) int64 32kB 3 3 3 3 3 3 3 3 ... 3 3 3 3 3 1 1
          tree_depth       (chain, draw) int64 32kB 2 2 2 2 2 2 2 2 ... 2 2 2 2 2 1 1
          lp               (chain, draw) float64 32kB 144.2 145.0 ... 145.5 145.5
      Attributes:
          created_at:                  2025-09-11T05:42:47.866678+00:00
          arviz_version:               0.21.0
          modeling_interface:          bambi
          modeling_interface_version:  0.15.1.dev17+g2db951a93.d20250910

    • <xarray.Dataset> Size: 2kB
      Dimensions:  (__obs__: 100)
      Coordinates:
        * __obs__  (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
      Data variables:
          y        (__obs__) float64 800B 0.9823 -0.1276 1.024 ... -0.4394 0.2223
      Attributes:
          created_at:                  2025-09-11T05:42:47.867365+00:00
          arviz_version:               0.21.0
          inference_library:           numpyro
          inference_library_version:   0.19.0
          sampling_time:               0.923537
          tuning_steps:                1000
          modeling_interface:          bambi
          modeling_interface_version:  0.15.1.dev17+g2db951a93.d20250910

nutpie

nutpie_idata = model.fit(inference_method="nutpie", progressbar=False)
/Users/gabestechschulte/projects/bambi/.venv/lib/python3.12/site-packages/pymc/sampling/mcmc.py:335: UserWarning: `var_names` are currently ignored by the nutpie sampler
  warnings.warn(
nutpie_idata
arviz.InferenceData
    • <xarray.Dataset> Size: 3MB
      Dimensions:      (chain: 4, draw: 1000, __obs__: 100)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
        * __obs__      (__obs__) int64 800B 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99
      Data variables:
          sigma_log__  (chain, draw) float64 32kB -0.01424 -0.0169 ... -0.06667
          Intercept    (chain, draw) float64 32kB -0.06445 -0.0443 ... 0.09378
          x            (chain, draw) float64 32kB 0.305 0.1792 0.1207 ... 0.6013 0.288
          sigma        (chain, draw) float64 32kB 0.9859 0.9832 ... 0.9713 0.9355
          mu           (chain, draw, __obs__) float64 3MB -0.3816 0.1644 ... -0.01512
      Attributes:
          created_at:                  2025-09-11T05:43:00.405915+00:00
          arviz_version:               0.21.0
          inference_library:           nutpie
          inference_library_version:   0.15.2
          sampling_time:               0.03235673904418945
          tuning_steps:                1000
          modeling_interface:          bambi
          modeling_interface_version:  0.15.1.dev17+g2db951a93.d20250910

    • <xarray.Dataset> Size: 336kB
      Dimensions:               (chain: 4, draw: 1000)
      Coordinates:
        * chain                 (chain) int64 32B 0 1 2 3
        * draw                  (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999
      Data variables:
          depth                 (chain, draw) uint64 32kB 2 2 1 2 2 2 ... 2 2 2 2 2 2
          maxdepth_reached      (chain, draw) bool 4kB False False ... False False
          index_in_trajectory   (chain, draw) int64 32kB 2 1 -1 2 -3 ... 2 -2 -1 2 -2
          logp                  (chain, draw) float64 32kB -144.3 -145.6 ... -145.0
          energy                (chain, draw) float64 32kB 145.0 145.7 ... 146.6 145.5
          diverging             (chain, draw) bool 4kB False False ... False False
          energy_error          (chain, draw) float64 32kB -0.1328 0.3533 ... -0.02633
          step_size             (chain, draw) float64 32kB 1.052 1.052 ... 1.088 1.088
          step_size_bar         (chain, draw) float64 32kB 1.052 1.052 ... 1.088 1.088
          mean_tree_accept      (chain, draw) float64 32kB 1.0 0.9008 ... 1.0 1.0
          mean_tree_accept_sym  (chain, draw) float64 32kB 0.94 0.9265 ... 0.9042
          n_steps               (chain, draw) uint64 32kB 3 3 1 3 3 3 ... 3 3 3 3 3 3
      Attributes:
          created_at:                  2025-09-11T05:43:00.401014+00:00
          arviz_version:               0.21.0
          modeling_interface:          bambi
          modeling_interface_version:  0.15.1.dev17+g2db951a93.d20250910

    • <xarray.Dataset> Size: 2kB
      Dimensions:  (__obs__: 100)
      Coordinates:
        * __obs__  (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
      Data variables:
          y        (__obs__) float64 800B 0.9823 -0.1276 1.024 ... -0.4394 0.2223
      Attributes:
          created_at:                  2025-09-11T05:43:00.405522+00:00
          arviz_version:               0.21.0
          inference_library:           pymc
          inference_library_version:   5.23.0
          modeling_interface:          bambi
          modeling_interface_version:  0.15.1.dev17+g2db951a93.d20250910

    • <xarray.Dataset> Size: 3MB
      Dimensions:      (chain: 4, draw: 1000, __obs__: 100)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
        * __obs__      (__obs__) int64 800B 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99
      Data variables:
          sigma_log__  (chain, draw) float64 32kB -0.6316 -0.6316 ... 0.02711
          Intercept    (chain, draw) float64 32kB 0.4837 0.4837 ... 0.09586 -0.04978
          x            (chain, draw) float64 32kB -0.9797 -0.9797 ... 0.4696 0.3999
          sigma        (chain, draw) float64 32kB 0.5317 0.5317 7.545 ... 0.9999 1.027
          mu           (chain, draw, __obs__) float64 3MB 1.447 -0.3075 ... -0.1782
      Attributes:
          created_at:                  2025-09-11T05:43:00.398914+00:00
          arviz_version:               0.21.0
          modeling_interface:          bambi
          modeling_interface_version:  0.15.1.dev17+g2db951a93.d20250910

    • <xarray.Dataset> Size: 336kB
      Dimensions:               (chain: 4, draw: 1000)
      Coordinates:
        * chain                 (chain) int64 32B 0 1 2 3
        * draw                  (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999
      Data variables:
          depth                 (chain, draw) uint64 32kB 3 0 6 5 1 0 ... 1 2 1 2 2 2
          maxdepth_reached      (chain, draw) bool 4kB False False ... False False
          index_in_trajectory   (chain, draw) int64 32kB 0 0 -14 -6 0 ... -2 1 -2 2 2
          logp                  (chain, draw) float64 32kB -453.4 -453.4 ... -144.1
          energy                (chain, draw) float64 32kB 453.6 455.3 ... 144.8 145.3
          diverging             (chain, draw) bool 4kB False True ... False False
          energy_error          (chain, draw) float64 32kB 0.0 0.0 ... -0.0436 -0.2045
          step_size             (chain, draw) float64 32kB 3.736 0.3684 ... 1.088
          step_size_bar         (chain, draw) float64 32kB 3.736 0.9423 ... 1.088
          mean_tree_accept      (chain, draw) float64 32kB 1.312e-28 0.0 ... 0.9112
          mean_tree_accept_sym  (chain, draw) float64 32kB 2.624e-28 0.0 ... 0.9075
          n_steps               (chain, draw) uint64 32kB 7 1 63 31 1 1 ... 3 3 3 3 3
      Attributes:
          created_at:                  2025-09-11T05:43:00.403552+00:00
          arviz_version:               0.21.0
          modeling_interface:          bambi
          modeling_interface_version:  0.15.1.dev17+g2db951a93.d20250910

%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Thu Sep 11 2025

Python implementation: CPython
Python version       : 3.12.5
IPython version      : 9.5.0

pandas: 2.3.0
numpy : 2.2.6
arviz : 0.21.0
bambi : 0.15.1.dev17+g2db951a93.d20250910

Watermark: 2.5.0