Alternative sampling backends

In Bambi, the sampler is selected automatically based on the types of variables in the model. Bambi supports both MCMC and variational inference for fitting models. By default, Bambi samples with PyMC's implementation of the adaptive Hamiltonian Monte Carlo (HMC) algorithm known as the No-U-Turn Sampler (NUTS). This sampler is a good choice for many models. However, PyMC is not the only library that implements NUTS.

To this end, Bambi also supports the NumPyro, Blackjax, and Nutpie implementations of NUTS. This notebook shows how to use these alternative samplers in Bambi.

Note: To use these samplers, you need to install numpyro, blackjax, and/or nutpie with a package manager of your choice.
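You can check which of these optional packages are available by asking Python's import system, without importing anything heavy. The sketch below uses only the standard library. The helper name `installed_backends` is hypothetical. It assumes each package is importable under the same name as the `inference_method` string used later in this notebook.

```python
import importlib.util

# Assumption: each optional backend is importable under the same name
# that this notebook later passes as inference_method.
OPTIONAL_BACKENDS = ("blackjax", "numpyro", "nutpie")


def installed_backends(names=OPTIONAL_BACKENDS):
    """Map each top-level package name to True if it can be imported, else False.

    For top-level names, importlib.util.find_spec only locates the module and
    does not execute it, so the check stays cheap even for JAX-based packages.
    """
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, available in installed_backends().items():
        print(f"{name:>9}: {'installed' if available else 'missing'}")
```

A backend reported as missing will fail if its name is passed as `inference_method`.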

import bambi as bmb
import numpy as np
import pandas as pd

Specifying an inference_method

To demonstrate the different inference methods, we will first simulate data and build a model.

num_samples = 100
num_features = 1
noise_std = 1.0
random_seed = 42

rng = np.random.default_rng(random_seed)

coefficients = rng.normal(size=num_features)
X = rng.normal(size=(num_samples, num_features))
error = rng.normal(scale=noise_std, size=num_samples)
y = X @ coefficients + error

data = pd.DataFrame({"y": y, "x": X.flatten()})
model = bmb.Model("y ~ x", data)

By default, Bambi uses the PyMC NUTS implementation. To use a different backend, pass the name of the desired MCMC implementation to the inference_method parameter of the fit method.
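The `%%time` cells below time each backend interactively. You can script the same comparison by looping over backend names. In the sketch below, `fit_and_time` is a hypothetical helper. The only Bambi call it relies on is `model.fit(inference_method=..., **kwargs)`, as used in this notebook. Wall-clock time also covers model compilation and data conversion. This likely explains why the `%%time` wall times below are larger than the `sampling_time` attribute stored in each result.

```python
import time


def fit_and_time(model, inference_methods, **fit_kwargs):
    """Fit `model` once per backend name and record the wall-clock time of each fit.

    `model` is expected to be a bambi.Model; any object with a compatible
    .fit(inference_method=..., **kwargs) method works. Extra keyword
    arguments (for example progressbar=False) are forwarded to every fit.
    Returns {method: (result, seconds)} in the order the methods were given.
    """
    results = {}
    for method in inference_methods:
        start = time.perf_counter()
        result = model.fit(inference_method=method, **fit_kwargs)
        results[method] = (result, time.perf_counter() - start)
    return results
```

For example: `fit_and_time(model, ["blackjax", "numpyro", "nutpie"], progressbar=False)`.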

Blackjax

%%time
blackjax_nuts_idata = model.fit(inference_method="blackjax", progressbar=False)
WARNING:2025-09-28 13:00:20,305:jax._src.xla_bridge:864: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
CPU times: user 7.64 s, sys: 3.05 s, total: 10.7 s
Wall time: 7.19 s
blackjax_nuts_idata
arviz.InferenceData
    • posterior
      <xarray.Dataset> Size: 104kB
      Dimensions:    (chain: 4, draw: 1000)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
      Data variables:
          Intercept  (chain, draw) float64 32kB 0.004918 0.02806 ... 0.1205 -0.01849
          x          (chain, draw) float64 32kB 0.3537 0.1788 0.473 ... 0.3752 0.4473
          sigma      (chain, draw) float64 32kB 1.058 1.083 0.9034 ... 1.033 0.9862
      Attributes:
          created_at:                  2025-09-28T16:00:24.433389+00:00
          arviz_version:               0.22.0
          inference_library:           blackjax
          inference_library_version:   1.2.4
          sampling_time:               2.107706
          tuning_steps:                1000
          modeling_interface:          bambi
          modeling_interface_version:  0.14.1.dev56+gd93591cd2.d20250927

    • sample_stats
      <xarray.Dataset> Size: 172kB
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          acceptance_rate  (chain, draw) float64 32kB 1.0 0.8539 ... 0.9393 0.9043
          diverging        (chain, draw) bool 4kB False False False ... False False
          energy           (chain, draw) float64 32kB 146.1 147.1 ... 145.1 146.3
          lp               (chain, draw) float64 32kB -144.5 -146.4 ... -144.8 -143.8
          n_steps          (chain, draw) int64 32kB 3 3 3 3 3 7 7 3 ... 7 5 7 7 7 7 3
          tree_depth       (chain, draw) int64 32kB 2 2 2 2 2 3 3 2 ... 3 3 3 3 3 3 2
      Attributes:
          created_at:                  2025-09-28T16:00:24.442990+00:00
          arviz_version:               0.22.0
          modeling_interface:          bambi
          modeling_interface_version:  0.14.1.dev56+gd93591cd2.d20250927

    • observed_data
      <xarray.Dataset> Size: 2kB
      Dimensions:  (__obs__: 100)
      Coordinates:
        * __obs__  (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
      Data variables:
          y        (__obs__) float64 800B 0.9823 -0.1276 1.024 ... -0.4394 0.2223
      Attributes:
          created_at:                  2025-09-28T16:00:24.445274+00:00
          arviz_version:               0.22.0
          inference_library:           blackjax
          inference_library_version:   1.2.4
          sampling_time:               2.107706
          tuning_steps:                1000
          modeling_interface:          bambi
          modeling_interface_version:  0.14.1.dev56+gd93591cd2.d20250927

NumPyro

%%time
numpyro_nuts_idata = model.fit(inference_method="numpyro", progressbar=False)
CPU times: user 3.43 s, sys: 294 ms, total: 3.72 s
Wall time: 1.55 s
numpyro_nuts_idata
arviz.InferenceData
    • posterior
      <xarray.Dataset> Size: 104kB
      Dimensions:    (chain: 4, draw: 1000)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
      Data variables:
          Intercept  (chain, draw) float64 32kB 0.184 -0.1767 ... -0.1061 0.07233
          x          (chain, draw) float64 32kB 0.3783 0.4742 0.5178 ... 0.5458 0.3218
          sigma      (chain, draw) float64 32kB 0.9515 1.011 1.006 ... 1.018 0.939
      Attributes:
          created_at:                  2025-09-28T16:00:26.212193+00:00
          arviz_version:               0.22.0
          inference_library:           numpyro
          inference_library_version:   0.19.0
          sampling_time:               1.23596
          tuning_steps:                1000
          modeling_interface:          bambi
          modeling_interface_version:  0.14.1.dev56+gd93591cd2.d20250927

    • sample_stats
      <xarray.Dataset> Size: 204kB
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          acceptance_rate  (chain, draw) float64 32kB 0.8278 1.0 0.9665 ... 1.0 0.814
          step_size        (chain, draw) float64 32kB 0.7569 0.7569 ... 0.9014 0.9014
          diverging        (chain, draw) bool 4kB False False False ... False False
          energy           (chain, draw) float64 32kB 146.4 145.8 ... 145.7 147.1
          n_steps          (chain, draw) int64 32kB 3 7 3 7 7 3 7 7 ... 3 3 3 3 7 1 11
          tree_depth       (chain, draw) int64 32kB 2 3 2 3 3 2 3 3 ... 2 2 2 2 3 1 4
          lp               (chain, draw) float64 32kB 145.7 145.6 ... 145.0 144.5
      Attributes:
          created_at:                  2025-09-28T16:00:26.216944+00:00
          arviz_version:               0.22.0
          modeling_interface:          bambi
          modeling_interface_version:  0.14.1.dev56+gd93591cd2.d20250927

    • observed_data
      <xarray.Dataset> Size: 2kB
      Dimensions:  (__obs__: 100)
      Coordinates:
        * __obs__  (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
      Data variables:
          y        (__obs__) float64 800B 0.9823 -0.1276 1.024 ... -0.4394 0.2223
      Attributes:
          created_at:                  2025-09-28T16:00:26.218137+00:00
          arviz_version:               0.22.0
          inference_library:           numpyro
          inference_library_version:   0.19.0
          sampling_time:               1.23596
          tuning_steps:                1000
          modeling_interface:          bambi
          modeling_interface_version:  0.14.1.dev56+gd93591cd2.d20250927

nutpie

nutpie_idata = model.fit(inference_method="nutpie", progressbar=False)
nutpie_idata
arviz.InferenceData
    • posterior
      <xarray.Dataset> Size: 136kB
      Dimensions:      (chain: 4, draw: 1000)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
      Data variables:
          sigma_log__  (chain, draw) float64 32kB -0.09139 0.02693 ... 0.1097 -0.1576
          Intercept    (chain, draw) float64 32kB 0.03335 -0.01864 ... 0.07132
          x            (chain, draw) float64 32kB 0.3313 0.53 0.3076 ... 0.5508 0.3157
          sigma        (chain, draw) float64 32kB 0.9127 1.027 0.8574 ... 1.116 0.8542
      Attributes:
          created_at:                  2025-09-28T16:00:34.798058+00:00
          arviz_version:               0.22.0
          inference_library:           nutpie
          inference_library_version:   0.15.2
          sampling_time:               0.0895071029663086
          tuning_steps:                1000
          modeling_interface:          bambi
          modeling_interface_version:  0.14.1.dev56+gd93591cd2.d20250927

    • sample_stats
      <xarray.Dataset> Size: 336kB
      Dimensions:               (chain: 4, draw: 1000)
      Coordinates:
        * chain                 (chain) int64 32B 0 1 2 3
        * draw                  (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999
      Data variables:
          depth                 (chain, draw) uint64 32kB 2 2 2 2 2 2 ... 2 2 2 2 2 2
          maxdepth_reached      (chain, draw) bool 4kB False False ... False False
          index_in_trajectory   (chain, draw) int64 32kB 1 -3 2 -2 2 ... 3 -3 -2 2 -3
          logp                  (chain, draw) float64 32kB -144.5 -144.4 ... -146.3
          energy                (chain, draw) float64 32kB 145.0 144.5 ... 146.1 148.9
          diverging             (chain, draw) bool 4kB False False ... False False
          energy_error          (chain, draw) float64 32kB 0.08409 ... -0.05117
          step_size             (chain, draw) float64 32kB 1.097 1.097 ... 1.103 1.103
          step_size_bar         (chain, draw) float64 32kB 1.097 1.097 ... 1.103 1.103
          mean_tree_accept      (chain, draw) float64 32kB 0.8706 1.0 ... 0.8117 0.84
          mean_tree_accept_sym  (chain, draw) float64 32kB 0.9303 0.9425 ... 0.8828
          n_steps               (chain, draw) uint64 32kB 3 3 3 3 3 3 ... 3 3 3 3 3 3
      Attributes:
          created_at:                  2025-09-28T16:00:34.788419+00:00
          arviz_version:               0.22.0
          modeling_interface:          bambi
          modeling_interface_version:  0.14.1.dev56+gd93591cd2.d20250927

    • observed_data
      <xarray.Dataset> Size: 2kB
      Dimensions:  (__obs__: 100)
      Coordinates:
        * __obs__  (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
      Data variables:
          y        (__obs__) float64 800B 0.9823 -0.1276 1.024 ... -0.4394 0.2223
      Attributes:
          created_at:                  2025-09-28T16:00:34.797508+00:00
          arviz_version:               0.22.0
          inference_library:           pymc
          inference_library_version:   3.9.2+2907.g7a3db78e6
          modeling_interface:          bambi
          modeling_interface_version:  0.14.1.dev56+gd93591cd2.d20250927

    • warmup_posterior
      <xarray.Dataset> Size: 136kB
      Dimensions:      (chain: 4, draw: 1000)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
      Data variables:
          sigma_log__  (chain, draw) float64 32kB -0.6403 -0.6403 ... -0.1866 -0.1695
          Intercept    (chain, draw) float64 32kB -0.7069 -0.7069 ... -0.06133
          x            (chain, draw) float64 32kB -0.6089 -0.6089 ... 0.5059 0.5003
          sigma        (chain, draw) float64 32kB 0.5271 0.5271 ... 0.8298 0.8441
      Attributes:
          created_at:                  2025-09-28T16:00:34.783875+00:00
          arviz_version:               0.22.0
          modeling_interface:          bambi
          modeling_interface_version:  0.14.1.dev56+gd93591cd2.d20250927

    • warmup_sample_stats
      <xarray.Dataset> Size: 336kB
      Dimensions:               (chain: 4, draw: 1000)
      Coordinates:
        * chain                 (chain) int64 32B 0 1 2 3
        * draw                  (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999
      Data variables:
          depth                 (chain, draw) uint64 32kB 3 0 1 2 1 0 ... 2 2 2 2 1 1
          maxdepth_reached      (chain, draw) bool 4kB False False ... False False
          index_in_trajectory   (chain, draw) int64 32kB 0 0 -1 -2 0 ... 2 -3 -2 0 -1
          logp                  (chain, draw) float64 32kB -401.7 -401.7 ... -146.2
          energy                (chain, draw) float64 32kB 402.3 402.8 ... 152.4 147.0
          diverging             (chain, draw) bool 4kB False True ... False False
          energy_error          (chain, draw) float64 32kB 0.0 0.0 ... 0.0 -0.1304
          step_size             (chain, draw) float64 32kB 3.736 0.3684 ... 1.103
          step_size_bar         (chain, draw) float64 32kB 3.736 0.9423 ... 1.103
          mean_tree_accept      (chain, draw) float64 32kB 1.58e-34 0.0 ... 1.0
          mean_tree_accept_sym  (chain, draw) float64 32kB 3.16e-34 0.0 ... 0.9349
          n_steps               (chain, draw) uint64 32kB 7 1 1 3 1 1 ... 3 3 3 3 1 1
      Attributes:
          created_at:                  2025-09-28T16:00:34.792568+00:00
          arviz_version:               0.22.0
          modeling_interface:          bambi
          modeling_interface_version:  0.14.1.dev56+gd93591cd2.d20250927
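All three backends sample the same model, so their posterior means should agree up to Monte Carlo error. This is a useful sanity check when switching backends. The outputs also show that the backends do not use identical names: nutpie additionally stores the unconstrained `sigma_log__` in its posterior and reports `mean_tree_accept` rather than `acceptance_rate`. The sketch below is hypothetical (`backend_summary` is not a Bambi function). It selects variables explicitly so that `sigma_log__` is skipped. It also treats `mean_tree_accept` as nutpie's closest counterpart to `acceptance_rate`, which is an assumption.

```python
# Acceptance-statistic names seen in the outputs above. Treating nutpie's
# mean_tree_accept as equivalent to acceptance_rate is an assumption.
ACCEPTANCE_NAMES = ("acceptance_rate", "mean_tree_accept")


def backend_summary(results, var_names=("Intercept", "x", "sigma")):
    """Collect comparable numbers from several InferenceData objects.

    `results` maps a backend name to the InferenceData returned by model.fit.
    Only `var_names` are read from the posterior, so extra variables such as
    nutpie's sigma_log__ are ignored.
    """
    summary = {}
    for backend, idata in results.items():
        stats = idata.sample_stats
        # Posterior mean of each parameter over all chains and draws.
        row = {var: float(idata.posterior[var].mean()) for var in var_names}
        # Every backend above writes a boolean "diverging" statistic.
        row["divergences"] = int(stats["diverging"].sum())
        # Use the first acceptance statistic this backend provides, if any.
        accept = next((name for name in ACCEPTANCE_NAMES if name in stats), None)
        row["mean_accept"] = float(stats[accept].mean()) if accept is not None else None
        summary[backend] = row
    return summary
```

For example, call `backend_summary({"blackjax": blackjax_nuts_idata, "numpyro": numpyro_nuts_idata, "nutpie": nutpie_idata})`. For full diagnostics such as R-hat and effective sample size, ArviZ's `az.summary(idata, var_names=["Intercept", "x", "sigma"])` is the more complete tool.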

%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Sun Sep 28 2025

Python implementation: CPython
Python version       : 3.13.7
IPython version      : 9.4.0

numpy : 2.3.3
pandas: 2.3.2
bambi : 0.14.1.dev56+gd93591cd2.d20250927

Watermark: 2.5.0