Alternative sampling backends

In Bambi, the sampler is selected automatically based on the types of variables in the model. For inference, Bambi supports both MCMC and variational inference. By default, Bambi uses PyMC’s implementation of the adaptive Hamiltonian Monte Carlo (HMC) algorithm, also known as the No-U-Turn Sampler (NUTS). This sampler is a good choice for many models, but it is not the only sampling method, nor is PyMC the only library implementing NUTS.

To this end, Bambi supports multiple backends for MCMC sampling, such as NumPyro and Blackjax. This notebook covers how to use these alternatives in Bambi.

Note: Bambi uses bayeux to access a variety of sampling backends. To use these backends, you will need to install the optional dependencies listed in the Bambi pyproject.toml file.
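For instance, from within a notebook the install step may look like the following. The "jax" extra name is an assumption here; check the pyproject.toml of the Bambi version you have installed.

%pip install "bambi[jax]"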

import arviz as az
import bambi as bmb
import bayeux as bx
import numpy as np
import pandas as pd

bayeux

Bambi leverages bayeux to access different sampling backends. In short, bayeux lets you write a probabilistic model in JAX and immediately have access to state-of-the-art inference methods.

Since the underlying Bambi model is a PyMC model, that PyMC model can be handed to bayeux, and we can then choose from a variety of MCMC methods to perform inference.

To demonstrate the available backends, we will first simulate data and build a model.

num_samples = 100
num_features = 1
noise_std = 1.0
random_seed = 42

np.random.seed(random_seed)

coefficients = np.random.randn(num_features)
X = np.random.randn(num_samples, num_features)
error = np.random.normal(scale=noise_std, size=num_samples)
y = X @ coefficients + error

data = pd.DataFrame({"y": y, "x": X.flatten()})
model = bmb.Model("y ~ x", data)
model.build()
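To illustrate what happens under the hood, the built PyMC model can also be handed to bayeux directly. This is only a minimal sketch, not how you would normally use Bambi: model.backend.model is an internal attribute that may change between versions, and the from_pymc / blackjax_nuts calls follow the bayeux API.

import jax

# The PyMC model that Bambi built (internal attribute; may change between versions).
pm_model = model.backend.model

# Wrap it as a bayeux model and sample with, e.g., Blackjax NUTS.
bx_model = bx.Model.from_pymc(pm_model)
raw_idata = bx_model.mcmc.blackjax_nuts(seed=jax.random.key(0))

In practice, Bambi does this wiring for you; the rest of this notebook uses model.fit.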

Calling bmb.inference_methods.names returns a nested dictionary of the available backends and the inference methods each provides.

methods = bmb.inference_methods.names
methods
{'pymc': {'mcmc': ['mcmc'], 'vi': ['vi']},
 'bayeux': {'mcmc': ['tfp_hmc',
   'tfp_nuts',
   'tfp_snaper_hmc',
   'blackjax_hmc',
   'blackjax_chees_hmc',
   'blackjax_meads_hmc',
   'blackjax_nuts',
   'blackjax_hmc_pathfinder',
   'blackjax_nuts_pathfinder',
   'flowmc_rqspline_hmc',
   'flowmc_rqspline_mala',
   'flowmc_realnvp_hmc',
   'flowmc_realnvp_mala',
   'numpyro_hmc',
   'numpyro_nuts']}}

With the PyMC backend, we have access to its implementation of the NUTS sampler and to mean-field variational inference.

methods["pymc"]
{'mcmc': ['mcmc'], 'vi': ['vi']}
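Although this notebook focuses on MCMC, the PyMC variational inference method is selected the same way. A minimal sketch, assuming the current Bambi behavior of returning a fitted PyMC approximation for "vi" (the variable names here are illustrative):

# Mean-field VI; returns a PyMC approximation rather than an InferenceData.
vi_approx = model.fit(inference_method="vi")

# Draw posterior samples from the fitted approximation.
vi_idata = vi_approx.sample(1000)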

bayeux gives us access to the TensorFlow Probability, Blackjax, flowMC, and NumPyro backends.

methods["bayeux"]
{'mcmc': ['tfp_hmc',
  'tfp_nuts',
  'tfp_snaper_hmc',
  'blackjax_hmc',
  'blackjax_chees_hmc',
  'blackjax_meads_hmc',
  'blackjax_nuts',
  'blackjax_hmc_pathfinder',
  'blackjax_nuts_pathfinder',
  'flowmc_rqspline_hmc',
  'flowmc_rqspline_mala',
  'flowmc_realnvp_hmc',
  'flowmc_realnvp_mala',
  'numpyro_hmc',
  'numpyro_nuts']}

The values under the mcmc and vi keys in this dictionary are the names you can pass to the inference_method argument of model.fit, as shown in the sections below.
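Since these names live in a nested dictionary, a small sketch to flatten them into a single list of valid inference_method values:

# Flatten {backend: {kind: [names]}} into one list of method names.
all_methods = [
    name
    for backend in methods.values()
    for names in backend.values()
    for name in names
]
print(all_methods)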

Specifying an inference_method

By default, Bambi uses the PyMC NUTS implementation. To use a different backend, pass the name of the bayeux MCMC method to the inference_method parameter of the fit method.

Blackjax

blackjax_nuts_idata = model.fit(inference_method="blackjax_nuts")
blackjax_nuts_idata
arviz.InferenceData
    • <xarray.Dataset> Size: 100kB
      Dimensions:    (chain: 8, draw: 500)
      Coordinates:
        * chain      (chain) int64 64B 0 1 2 3 4 5 6 7
        * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
      Data variables:
          Intercept  (chain, draw) float64 32kB -0.01389 0.1089 ... -0.000227 0.1499
          sigma      (chain, draw) float64 32kB 1.093 0.8295 1.09 ... 0.8503 0.9044
          x          (chain, draw) float64 32kB 0.3531 0.3635 0.3556 ... 0.3502 0.3066
      Attributes:
          created_at:                  2024-06-02T15:41:30.853458+00:00
          arviz_version:               0.18.0
          modeling_interface:          bambi
          modeling_interface_version:  0.13.1.dev44+g55aac858.d20240602

    • <xarray.Dataset> Size: 200kB
      Dimensions:          (chain: 8, draw: 500)
      Coordinates:
        * chain            (chain) int64 64B 0 1 2 3 4 5 6 7
        * draw             (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499
      Data variables:
          acceptance_rate  (chain, draw) float64 32kB 0.8636 0.9972 ... 0.8985 0.9411
          diverging        (chain, draw) bool 4kB False False False ... False False
          energy           (chain, draw) float64 32kB 142.0 142.5 ... 142.5 141.4
          lp               (chain, draw) float64 32kB -141.7 -141.2 ... -140.2 -140.6
          n_steps          (chain, draw) int64 32kB 7 7 7 7 7 3 7 7 ... 1 3 7 3 3 7 7
          step_size        (chain, draw) float64 32kB 0.6802 0.6802 ... 0.8167 0.8167
          tree_depth       (chain, draw) int64 32kB 3 3 3 3 3 2 3 3 ... 1 2 3 2 2 3 3
      Attributes:
          created_at:                  2024-06-02T15:41:30.856930+00:00
          arviz_version:               0.18.0
          modeling_interface:          bambi
          modeling_interface_version:  0.13.1.dev44+g55aac858.d20240602

    • <xarray.Dataset> Size: 2kB
      Dimensions:  (__obs__: 100)
      Coordinates:
        * __obs__  (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
      Data variables:
          y        (__obs__) float64 800B -0.4893 -0.021 -0.04577 ... -1.259 -0.3452
      Attributes:
          created_at:                  2024-06-02T15:41:30.853458+00:00
          arviz_version:               0.18.0
          modeling_interface:          bambi
          modeling_interface_version:  0.13.1.dev44+g55aac858.d20240602

Different backends use different naming conventions for the parameters specific to their MCMC methods. To set backend-specific parameters, pass your own kwargs to the fit method.

To identify the kwargs available for each method, call bmb.inference_methods.get_kwargs with the method name.

bmb.inference_methods.get_kwargs("blackjax_nuts")
{<function blackjax.adaptation.window_adaptation.window_adaptation(algorithm, logdensity_fn: Callable, is_mass_matrix_diagonal: bool = True, initial_step_size: float = 1.0, target_acceptance_rate: float = 0.8, progress_bar: bool = False, **extra_parameters) -> blackjax.base.AdaptationAlgorithm>: {'logdensity_fn': <function bayeux._src.shared.constrain.<locals>.wrap_log_density.<locals>.wrapped(args)>,
  'is_mass_matrix_diagonal': True,
  'initial_step_size': 1.0,
  'target_acceptance_rate': 0.8,
  'progress_bar': False,
  'algorithm': GenerateSamplingAPI(differentiable=<function as_top_level_api at 0x77fd1f34da80>, init=<function init at 0x77fd1f322a20>, build_kernel=<function build_kernel at 0x77fd1f333b00>)},
 'adapt.run': {'num_steps': 500},
 <function blackjax.mcmc.nuts.as_top_level_api(logdensity_fn: Callable, step_size: float, inverse_mass_matrix: Union[blackjax.mcmc.metrics.Metric, jax.Array, Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex, Iterable[ForwardRef('ArrayLikeTree')], Mapping[Any, ForwardRef('ArrayLikeTree')]]], jax.Array]], *, max_num_doublings: int = 10, divergence_threshold: int = 1000, integrator: Callable = <function generate_euclidean_integrator.<locals>.euclidean_integrator at 0x77fd1f323c40>) -> blackjax.base.SamplingAlgorithm>: {'max_num_doublings': 10,
  'divergence_threshold': 1000,
  'integrator': <function blackjax.mcmc.integrators.generate_euclidean_integrator.<locals>.euclidean_integrator(logdensity_fn: Callable, kinetic_energy_fn: blackjax.mcmc.metrics.KineticEnergy) -> Callable[[blackjax.mcmc.integrators.IntegratorState, float], blackjax.mcmc.integrators.IntegratorState]>,
  'logdensity_fn': <function bayeux._src.shared.constrain.<locals>.wrap_log_density.<locals>.wrapped(args)>,
  'step_size': 0.5},
 'extra_parameters': {'chain_method': 'vectorized',
  'num_chains': 8,
  'num_draws': 500,
  'num_adapt_draws': 500,
  'return_pytree': False}}

Now we can identify the kwargs we would like to change and pass them to the fit method.

kwargs = {
    "adapt.run": {"num_steps": 500},
    "num_chains": 4,
    "num_draws": 250,
    "num_adapt_draws": 250
}

blackjax_nuts_idata = model.fit(inference_method="blackjax_nuts", **kwargs)
blackjax_nuts_idata
arviz.InferenceData
    • <xarray.Dataset> Size: 26kB
      Dimensions:    (chain: 4, draw: 250)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 2kB 0 1 2 3 4 5 6 7 ... 243 244 245 246 247 248 249
      Data variables:
          Intercept  (chain, draw) float64 8kB 0.1186 0.1811 0.1516 ... 0.104 -0.01889
          sigma      (chain, draw) float64 8kB 0.9543 0.976 0.9225 ... 0.8462 0.9206
          x          (chain, draw) float64 8kB 0.1962 0.2625 0.2581 ... 0.3441 0.3412
      Attributes:
          created_at:                  2024-06-02T15:41:41.635714+00:00
          arviz_version:               0.18.0
          modeling_interface:          bambi
          modeling_interface_version:  0.13.1.dev44+g55aac858.d20240602

    • <xarray.Dataset> Size: 51kB
      Dimensions:          (chain: 4, draw: 250)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 2kB 0 1 2 3 4 5 6 ... 244 245 246 247 248 249
      Data variables:
          acceptance_rate  (chain, draw) float64 8kB 0.9572 0.9862 1.0 ... 0.921 1.0
          diverging        (chain, draw) bool 1kB False False False ... False False
          energy           (chain, draw) float64 8kB 144.3 142.1 141.9 ... 140.8 140.6
          lp               (chain, draw) float64 8kB -141.3 -141.4 ... -140.7 -139.4
          n_steps          (chain, draw) int64 8kB 3 7 3 3 7 3 3 3 ... 3 3 7 3 7 7 1 7
          step_size        (chain, draw) float64 8kB 0.8903 0.8903 ... 0.7551 0.7551
          tree_depth       (chain, draw) int64 8kB 2 3 2 2 3 2 2 2 ... 2 2 3 2 3 3 1 3
      Attributes:
          created_at:                  2024-06-02T15:41:41.639906+00:00
          arviz_version:               0.18.0
          modeling_interface:          bambi
          modeling_interface_version:  0.13.1.dev44+g55aac858.d20240602

    • <xarray.Dataset> Size: 2kB
      Dimensions:  (__obs__: 100)
      Coordinates:
        * __obs__  (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
      Data variables:
          y        (__obs__) float64 800B -0.4893 -0.021 -0.04577 ... -1.259 -0.3452
      Attributes:
          created_at:                  2024-06-02T15:41:41.635714+00:00
          arviz_version:               0.18.0
          modeling_interface:          bambi
          modeling_interface_version:  0.13.1.dev44+g55aac858.d20240602

TensorFlow Probability

tfp_nuts_idata = model.fit(inference_method="tfp_nuts")
tfp_nuts_idata
arviz.InferenceData
    • <xarray.Dataset> Size: 200kB
      Dimensions:    (chain: 8, draw: 1000)
      Coordinates:
        * chain      (chain) int64 64B 0 1 2 3 4 5 6 7
        * draw       (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
      Data variables:
          Intercept  (chain, draw) float64 64kB 0.2415 -0.0268 ... -0.07376 -0.05367
          sigma      (chain, draw) float64 64kB 0.9948 0.9385 0.9726 ... 0.8749 1.129
          x          (chain, draw) float64 64kB 0.3051 0.3062 0.1433 ... 0.2551 0.5439
      Attributes:
          created_at:                  2024-06-02T15:41:52.350361+00:00
          arviz_version:               0.18.0
          modeling_interface:          bambi
          modeling_interface_version:  0.13.1.dev44+g55aac858.d20240602

    • <xarray.Dataset> Size: 312kB
      Dimensions:          (chain: 8, draw: 1000)
      Coordinates:
        * chain            (chain) int64 64B 0 1 2 3 4 5 6 7
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          accept_ratio     (chain, draw) float64 64kB 0.9302 1.0 ... 0.9067 0.7528
          diverging        (chain, draw) bool 8kB False False False ... False False
          is_accepted      (chain, draw) bool 8kB True True True ... True True True
          n_steps          (chain, draw) int32 32kB 7 3 3 3 3 7 7 3 ... 7 7 7 3 3 7 7
          step_size        (chain, draw) float64 64kB 0.545 0.545 0.545 ... nan nan
          target_log_prob  (chain, draw) float64 64kB -142.4 -139.5 ... -140.7 -144.1
          tune             (chain, draw) float64 64kB 0.0 0.0 0.0 0.0 ... nan nan nan
      Attributes:
          created_at:                  2024-06-02T15:41:52.353603+00:00
          arviz_version:               0.18.0
          modeling_interface:          bambi
          modeling_interface_version:  0.13.1.dev44+g55aac858.d20240602

    • <xarray.Dataset> Size: 2kB
      Dimensions:  (__obs__: 100)
      Coordinates:
        * __obs__  (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
      Data variables:
          y        (__obs__) float64 800B -0.4893 -0.021 -0.04577 ... -1.259 -0.3452
      Attributes:
          created_at:                  2024-06-02T15:41:52.350361+00:00
          arviz_version:               0.18.0
          modeling_interface:          bambi
          modeling_interface_version:  0.13.1.dev44+g55aac858.d20240602

NumPyro

numpyro_nuts_idata = model.fit(inference_method="numpyro_nuts")
numpyro_nuts_idata
sample: 100%|██████████| 1500/1500 [00:06<00:00, 242.04it/s]
arviz.InferenceData
    • <xarray.Dataset> Size: 200kB
      Dimensions:    (chain: 8, draw: 1000)
      Coordinates:
        * chain      (chain) int64 64B 0 1 2 3 4 5 6 7
        * draw       (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
      Data variables:
          Intercept  (chain, draw) float64 64kB -0.01687 0.06615 ... 0.1263 0.03044
          sigma      (chain, draw) float64 64kB 0.965 0.8374 1.078 ... 1.002 0.8794
          x          (chain, draw) float64 64kB 0.2405 0.4685 0.2349 ... 0.3402 0.3522
      Attributes:
          created_at:                  2024-06-02T15:42:01.224796+00:00
          arviz_version:               0.18.0
          inference_library:           numpyro
          inference_library_version:   0.15.0
          modeling_interface:          bambi
          modeling_interface_version:  0.13.1.dev44+g55aac858.d20240602

    • <xarray.Dataset> Size: 400kB
      Dimensions:          (chain: 8, draw: 1000)
      Coordinates:
        * chain            (chain) int64 64B 0 1 2 3 4 5 6 7
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          acceptance_rate  (chain, draw) float64 64kB 0.9735 0.7476 ... 0.9667 0.8725
          diverging        (chain, draw) bool 8kB False False False ... False False
          energy           (chain, draw) float64 64kB 140.5 145.1 ... 140.7 141.8
          lp               (chain, draw) float64 64kB 140.1 141.2 ... 140.4 139.6
          n_steps          (chain, draw) int64 64kB 7 7 7 3 3 1 3 3 ... 11 3 3 3 3 3 3
          step_size        (chain, draw) float64 64kB 0.7685 0.7685 ... 0.8865 0.8865
          tree_depth       (chain, draw) int64 64kB 3 3 3 2 2 1 2 2 ... 4 2 2 2 2 2 2
      Attributes:
          created_at:                  2024-06-02T15:42:01.260288+00:00
          arviz_version:               0.18.0
          inference_library:           numpyro
          inference_library_version:   0.15.0
          modeling_interface:          bambi
          modeling_interface_version:  0.13.1.dev44+g55aac858.d20240602

    • <xarray.Dataset> Size: 2kB
      Dimensions:  (__obs__: 100)
      Coordinates:
        * __obs__  (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
      Data variables:
          y        (__obs__) float64 800B -0.4893 -0.021 -0.04577 ... -1.259 -0.3452
      Attributes:
          created_at:                  2024-06-02T15:42:01.224796+00:00
          arviz_version:               0.18.0
          inference_library:           numpyro
          inference_library_version:   0.15.0
          modeling_interface:          bambi
          modeling_interface_version:  0.13.1.dev44+g55aac858.d20240602

flowMC

flowmc_idata = model.fit(inference_method="flowmc_realnvp_hmc")
flowmc_idata
['n_dim', 'n_chains', 'n_local_steps', 'n_global_steps', 'n_loop', 'output_thinning', 'verbose']
Global Tuning: 100%|██████████| 5/5 [00:44<00:00,  8.89s/it]
Global Sampling: 100%|██████████| 5/5 [00:00<00:00, 25.89it/s]
arviz.InferenceData
    • <xarray.Dataset> Size: 244kB
      Dimensions:    (chain: 20, draw: 500)
      Coordinates:
        * chain      (chain) int64 160B 0 1 2 3 4 5 6 7 8 ... 12 13 14 15 16 17 18 19
        * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
      Data variables:
          Intercept  (chain, draw) float64 80kB 0.07083 0.06709 ... 0.06182 -0.04028
          sigma      (chain, draw) float64 80kB 0.9755 0.9504 0.9298 ... 0.8554 0.9118
          x          (chain, draw) float64 80kB 0.382 0.3589 0.2673 ... 0.4581 0.3594
      Attributes:
          created_at:                  2024-06-02T15:42:49.303545+00:00
          arviz_version:               0.18.0
          modeling_interface:          bambi
          modeling_interface_version:  0.13.1.dev44+g55aac858.d20240602

    • <xarray.Dataset> Size: 2kB
      Dimensions:  (__obs__: 100)
      Coordinates:
        * __obs__  (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
      Data variables:
          y        (__obs__) float64 800B -0.4893 -0.021 -0.04577 ... -1.259 -0.3452
      Attributes:
          created_at:                  2024-06-02T15:42:49.303545+00:00
          arviz_version:               0.18.0
          modeling_interface:          bambi
          modeling_interface_version:  0.13.1.dev44+g55aac858.d20240602

Sampler comparisons

With ArviZ, we can compare the inference result summaries of the samplers. Note: we can’t use az.compare here because not every inference data object contains the pointwise log-likelihood values, so an error would be raised.
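To confirm which groups each result actually carries, a quick sketch using ArviZ’s InferenceData.groups():

# az.compare needs a `log_likelihood` group, which these results don't carry.
for name, idata in [
    ("blackjax_nuts", blackjax_nuts_idata),
    ("tfp_nuts", tfp_nuts_idata),
    ("numpyro_nuts", numpyro_nuts_idata),
    ("flowmc_realnvp_hmc", flowmc_idata),
]:
    print(name, "log_likelihood" in idata.groups())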

az.summary(blackjax_nuts_idata)
            mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
Intercept  0.026  0.098  -0.152    0.206      0.003    0.003     796.0     648.0   1.01
sigma      0.945  0.070   0.817    1.074      0.002    0.002     970.0     759.0   1.00
x          0.355  0.103   0.157    0.532      0.003    0.002    1067.0     692.0   1.00
az.summary(tfp_nuts_idata)
            mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
Intercept  0.024  0.096  -0.157    0.204      0.001    0.001    7048.0    5524.0    1.0
sigma      0.948  0.068   0.829    1.083      0.001    0.001    7933.0    5659.0    1.0
x          0.361  0.103   0.168    0.550      0.001    0.001    6986.0    5702.0    1.0
az.summary(numpyro_nuts_idata)
            mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
Intercept  0.025  0.095  -0.162    0.196      0.001    0.001    7396.0    5859.0    1.0
sigma      0.946  0.068   0.819    1.075      0.001    0.001    7131.0    5580.0    1.0
x          0.361  0.106   0.171    0.569      0.001    0.001    7673.0    5905.0    1.0
az.summary(flowmc_idata)
            mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
Intercept  0.024  0.096  -0.149    0.207      0.003    0.002     876.0     615.0   1.02
sigma      0.947  0.067   0.822    1.066      0.001    0.001    5554.0    5920.0   1.00
x          0.361  0.104   0.161    0.550      0.001    0.001    5081.0    4653.0   1.00
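To view the four summaries side by side, one option is to concatenate them into a single hierarchical table; a sketch using pandas:

# Stack the per-sampler summaries into one table keyed by sampler name.
summaries = pd.concat(
    {
        "blackjax_nuts": az.summary(blackjax_nuts_idata),
        "tfp_nuts": az.summary(tfp_nuts_idata),
        "numpyro_nuts": az.summary(numpyro_nuts_idata),
        "flowmc_realnvp_hmc": az.summary(flowmc_idata),
    },
    names=["sampler", "parameter"],
)
print(summaries[["ess_bulk", "r_hat"]])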

Summary

Thanks to bayeux, we can use four different sampling backends and more than ten alternative MCMC methods in Bambi. Using these methods is as simple as passing the method name to the inference_method argument of the fit method.

%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Sun Jun 02 2024

Python implementation: CPython
Python version       : 3.11.9
IPython version      : 8.24.0

arviz : 0.18.0
pandas: 2.2.2
bayeux: 0.1.12
bambi : 0.13.1.dev44+g55aac858.d20240602
numpy : 1.26.4

Watermark: 2.4.3