Interpret Advanced Usage

The interpret module provides a set of helper functions to aid the user in more advanced and complex analyses not covered by the comparisons, predictions, and slopes functions. These helper functions are data_grid and select_draws. data_grid creates a pairwise grid of data points for the user to pass to model.predict. Subsequently, select_draws selects the draws from the posterior (or posterior predictive) group of the InferenceData object returned by the predict method that correspond to the data points that “produced” each draw.

With access to the appropriately indexed draws, and the data used to generate them, more complex analyses become possible, such as cross-comparisons and choosing which model parameter to compute a quantity of interest for, among others. Additionally, the user has more control over the data passed to model.predict. Below, we demonstrate how to use these helper functions: first to reproduce the results from the standard interpret API, and then to compute cross-comparisons.

import warnings

import arviz as az
import numpy as np
import pandas as pd

import bambi as bmb

from bambi.interpret.helpers import data_grid, select_draws

warnings.simplefilter(action='ignore', category=FutureWarning)

Zero Inflated Poisson

We will adopt the zero inflated Poisson (ZIP) model from the comparisons documentation to demonstrate the helper functions introduced above.

The ZIP model will be used to predict how many fish are caught by visitors at a state park using survey data. Many visitors catch zero fish, either because they did not fish at all or because they were unlucky. We would like to explicitly model this bimodal behavior (zero versus non-zero) using a zero inflated Poisson model, and to compare how different inputs of interest \(w\) and other covariate values \(c\) are associated with the number of fish caught. The dataset contains data on 250 groups that went to a state park to fish. Each group was asked how many fish they caught (count), how many children were in the group (child), how many people were in the group (persons), whether they used live bait (livebait), and whether or not they brought a camper to the park (camper).

fish_data = pd.read_csv("https://stats.idre.ucla.edu/stat/data/fish.csv")
cols = ["count", "livebait", "camper", "persons", "child"]
fish_data = fish_data[cols]
fish_data["child"] = fish_data["child"].astype(np.int8)
fish_data["persons"] = fish_data["persons"].astype(np.int8)
fish_data["livebait"] = pd.Categorical(fish_data["livebait"])
fish_data["camper"] = pd.Categorical(fish_data["camper"])
fish_model = bmb.Model(
    "count ~ livebait + camper + persons + child", 
    fish_data, 
    family='zero_inflated_poisson'
)

fish_idata = fish_model.fit(random_seed=1234)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [psi, Intercept, livebait, camper, persons, child]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 6 seconds.

Create a grid of data

data_grid allows you to create a pairwise grid, also known as a cross-join or cartesian product, of data using the covariates passed to conditional and the optional variable parameter. Covariates that are terms in the Bambi model but are not passed to conditional are set to typical values (e.g., mean or mode). If you are coming from R, this function is partially inspired by the data_grid function in the R package modelr.

There are two ways to create a pairwise grid of data:

  1. user-provided values are passed as a dictionary to conditional where the keys are the names of the covariates and the values are the values to use in the grid.
  2. a list of covariates where the elements are the names of the covariates to use in the grid. As only the names of the covariates were passed, default values are computed to construct the grid.

Any unspecified covariates, i.e., covariates not passed to conditional but are terms in the Bambi model, are set to their “typical” values such as mean or mode depending on the data type of the covariate.
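The "typical value" convention described above can be sketched as follows: the mean for numeric covariates and the mode for categorical ones. This is an illustrative sketch of the convention, not Bambi's exact implementation.

```python
import pandas as pd

# A sketch of the "typical value" rule: mean for numeric columns,
# mode for categorical columns. Illustrative only.
df = pd.DataFrame({
    "persons": [1, 2, 2, 4],
    "livebait": pd.Categorical([1, 1, 1, 0]),
})
typical = {
    col: df[col].mean() if pd.api.types.is_numeric_dtype(df[col]) else df[col].mode()[0]
    for col in df
}
print(typical)  # {'persons': 2.25, 'livebait': 1}
```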

User-provided values

To construct a pairwise grid of data for specific covariate values, pass a dictionary to conditional. The values of the dictionary can be of type int, float, list, or np.ndarray.

conditional = {
    "camper": np.array([0, 1]),
    "persons": np.arange(1, 5, 1),
    "child": np.array([1, 2, 3]),
}
user_passed_grid = data_grid(fish_model, conditional)
user_passed_grid.query("camper == 0")
Default computed for unspecified variable: livebait
camper persons child livebait
0 0 1 1 1
1 0 1 2 1
2 0 1 3 1
3 0 2 1 1
4 0 2 2 1
5 0 2 3 1
6 0 3 1 1
7 0 3 2 1
8 0 3 3 1
9 0 4 1 1
10 0 4 2 1
11 0 4 3 1

Subsetting by camper = 0, we can see that the grid contains every possible combination of the values in the dictionary (including the unspecified variable livebait). livebait has been set to 1 since this is the mode of that unspecified categorical variable.
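The grid above can be reproduced directly as a cartesian product of the value sets, which makes the 2 × 4 × 3 = 24-row shape explicit. This sketch uses itertools.product rather than Bambi's internals:

```python
from itertools import product

import pandas as pd

# Reproduce the pairwise grid as a cartesian product of the value sets.
conditional = {
    "camper": [0, 1],
    "persons": [1, 2, 3, 4],
    "child": [1, 2, 3],
}
grid = pd.DataFrame(product(*conditional.values()), columns=list(conditional))
grid["livebait"] = 1  # mode of the unspecified categorical variable
print(grid.shape)  # (24, 4)
```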

Default values

Alternatively, a list of covariates can be passed to conditional where the elements are the names of the covariates to use in the grid. By doing this, you are telling interpret to compute default values for these covariates. The pseudocode below outlines the logic and functions used to compute these default values:

if is_numeric_dtype(x) or is_float_dtype(x):
    values = np.linspace(np.min(x), np.max(x), 50)

elif is_integer_dtype(x):
    values = np.quantile(x, np.linspace(0, 1, 5))

elif is_categorical_dtype(x) or is_string_dtype(x) or is_object_dtype(x):
    values = np.unique(x)

conditional = ["camper", "persons", "child"]
default_grid = data_grid(fish_model, conditional)

default_grid.shape, user_passed_grid.shape
Default computed for conditional variable: camper, persons, child
Default computed for unspecified variable: livebait
((32, 4), (24, 4))

Notice how the resulting shapes differ between the user-passed and default grids. This is because the default values for child range from 0 to 3 (four values), whereas the user-passed grid specified 1 to 3.
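The quantile rule for integer covariates from the pseudocode above can be illustrated on a small array. Note that the raw quantiles need not be integers; the actual grid may further round or deduplicate them, so this is only a sketch of the rule:

```python
import numpy as np

# Default values for an integer covariate: the 0, 25, 50, 75, and 100
# percent quantiles of the observed data (raw, before any rounding).
child = np.array([0, 0, 0, 0, 1, 1, 2, 3])
values = np.quantile(child, np.linspace(0, 1, 5))
print(values)
```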

default_grid.query("camper == 0")
camper persons child livebait
0 0 1 0 1
1 0 1 1 1
2 0 1 2 1
3 0 1 3 1
4 0 2 0 1
5 0 2 1 1
6 0 2 2 1
7 0 2 3 1
8 0 3 0 1
9 0 3 1 1
10 0 3 2 1
11 0 3 3 1
12 0 4 0 1
13 0 4 1 1
14 0 4 2 1
15 0 4 3 1

Compute comparisons

To use data_grid to help generate data in computing comparisons or slopes, additional data is passed to the optional variable parameter. The name variable is an abstraction for the comparisons parameter contrast and slopes parameter wrt. If you have used any of the interpret functions, these parameter names should be familiar and the use of data_grid should be analogous to comparisons, predictions, and slopes.

variable can also be passed user-provided data (as a dictionary) or a string naming the covariate of interest; in the latter case, a default value is computed. Additionally, if an argument is passed for variable, then effect_type needs to be passed as well. This is because for comparisons and slopes, an epsilon value eps needs to be determined to compute the centered and finite difference, respectively. You can also pass a value for eps as a kwarg.

conditional = {
    "camper": np.array([0, 1]),
    "persons": np.arange(1, 5, 1),
    "child": np.array([1, 2, 3, 4])
}
variable = "livebait"

grid = data_grid(fish_model, conditional, variable, effect_type="comparisons")
Default computed for contrast variable: livebait
idata_grid = fish_model.predict(fish_idata, data=grid, inplace=False)

Select draws conditional on data

The second helper function to aid in more advanced analysis is select_draws. It selects the posterior or posterior predictive draws from the ArviZ InferenceData object returned by model.predict, given a conditional dictionary. The conditional dictionary represents the values that correspond to the draws to select.

For example, if we wanted to select the posterior draws where livebait = 0, all we need to do is pass a dictionary where the key is the name of the covariate and the value is the value we want to condition on (or select). The resulting object will contain the draws that correspond to the data points where livebait = 0. Additionally, you must pass the InferenceData object returned by model.predict, the data used to generate the predictions, and the name of the data variable data_var you would like to select from the InferenceData posterior group. If you specified to return the posterior predictive samples by passing model.predict(..., kind="pps"), you can use this group instead of the posterior group by passing group="posterior_predictive".
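Conceptually, this selection amounts to masking the rows of the data grid that match the condition and then indexing the observation axis of the draws with those rows. The sketch below illustrates the idea with plain NumPy and pandas; the names and shapes are illustrative, not Bambi's internals:

```python
import numpy as np
import pandas as pd

# Toy grid and draws with shape (chain, draw, obs).
grid = pd.DataFrame({
    "livebait": [0, 0, 1, 1],
    "persons": [1, 2, 1, 2],
})
rng = np.random.default_rng(0)
draws = rng.normal(size=(4, 1000, len(grid)))

# Mask the grid rows matching the condition, then index the obs axis.
condition = {"livebait": 0}
mask = np.logical_and.reduce(
    [grid[k].to_numpy() == v for k, v in condition.items()]
)
selected = draws[..., mask]
print(selected.shape)  # (4, 1000, 2)
```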

Below, it is demonstrated how to compute comparisons for mu (the mean of the count response) for the contrast livebait = [0, 1] using the posterior draws.

idata_grid
arviz.InferenceData
    • <xarray.Dataset> Size: 2MB
      Dimensions:       (chain: 4, draw: 1000, livebait_dim: 1, camper_dim: 1,
                         __obs__: 64)
      Coordinates:
        * chain         (chain) int64 32B 0 1 2 3
        * draw          (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
        * livebait_dim  (livebait_dim) <U1 4B '1'
        * camper_dim    (camper_dim) <U1 4B '1'
        * __obs__       (__obs__) int64 512B 0 1 2 3 4 5 6 7 ... 57 58 59 60 61 62 63
      Data variables:
          psi           (chain, draw) float64 32kB 0.4751 0.6612 ... 0.5565 0.5387
          Intercept     (chain, draw) float64 32kB -2.655 -2.48 ... -2.59 -2.517
          livebait      (chain, draw, livebait_dim) float64 32kB 1.936 1.784 ... 1.822
          camper        (chain, draw, camper_dim) float64 32kB 0.6152 ... 0.5638
          persons       (chain, draw) float64 32kB 0.864 0.8756 ... 0.8885 0.8896
          child         (chain, draw) float64 32kB -1.399 -1.272 ... -1.474 -1.467
          mu            (chain, draw, __obs__) float64 2MB 0.04118 0.2855 ... 0.08692
      Attributes:
          created_at:                  2025-09-28T17:53:04.333667+00:00
          arviz_version:               0.22.0
          inference_library:           pymc
          inference_library_version:   3.9.2+2907.g7a3db78e6
          sampling_time:               5.744408369064331
          tuning_steps:                1000
          modeling_interface:          bambi
          modeling_interface_version:  0.14.1.dev57+g7b2bb342c.d20250928

    • <xarray.Dataset> Size: 528kB
      Dimensions:                (chain: 4, draw: 1000)
      Coordinates:
        * chain                  (chain) int64 32B 0 1 2 3
        * draw                   (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999
      Data variables: (12/18)
          process_time_diff      (chain, draw) float64 32kB 0.00177 ... 0.001431
          index_in_trajectory    (chain, draw) int64 32kB -2 4 3 6 -6 ... 5 -4 5 3 13
          max_energy_error       (chain, draw) float64 32kB 2.184 1.865 ... 1.111
          perf_counter_start     (chain, draw) float64 32kB 2.045e+05 ... 2.045e+05
          reached_max_treedepth  (chain, draw) bool 4kB False False ... False False
          energy_error           (chain, draw) float64 32kB -0.1267 0.02556 ... 1.111
          ...                     ...
          n_steps                (chain, draw) float64 32kB 7.0 7.0 7.0 ... 7.0 15.0
          step_size_bar          (chain, draw) float64 32kB 0.463 0.463 ... 0.4197
          energy                 (chain, draw) float64 32kB 762.8 759.5 ... 755.8
          diverging              (chain, draw) bool 4kB False False ... False False
          perf_counter_diff      (chain, draw) float64 32kB 0.001946 ... 0.001552
          smallest_eigval        (chain, draw) float64 32kB nan nan nan ... nan nan
      Attributes:
          created_at:                  2025-09-28T17:53:04.353984+00:00
          arviz_version:               0.22.0
          inference_library:           pymc
          inference_library_version:   3.9.2+2907.g7a3db78e6
          sampling_time:               5.744408369064331
          tuning_steps:                1000
          modeling_interface:          bambi
          modeling_interface_version:  0.14.1.dev57+g7b2bb342c.d20250928

    • <xarray.Dataset> Size: 4kB
      Dimensions:  (__obs__: 250)
      Coordinates:
        * __obs__  (__obs__) int64 2kB 0 1 2 3 4 5 6 7 ... 243 244 245 246 247 248 249
      Data variables:
          count    (__obs__) int64 2kB 0 0 0 0 1 0 0 0 0 1 0 ... 4 1 1 0 1 0 0 0 0 0 0
      Attributes:
          created_at:                  2025-09-28T17:53:04.359378+00:00
          arviz_version:               0.22.0
          inference_library:           pymc
          inference_library_version:   3.9.2+2907.g7a3db78e6
          modeling_interface:          bambi
          modeling_interface_version:  0.14.1.dev57+g7b2bb342c.d20250928

draw_1 = select_draws(idata_grid, grid, {"livebait": 0}, "mu")
draw_2 = select_draws(idata_grid, grid, {"livebait": 1}, "mu")

comparison_mean = (draw_2 - draw_1).mean(("chain", "draw"))
comparison_hdi = az.hdi(draw_2 - draw_1)

comparison_df = pd.DataFrame(
    {
        "mean": comparison_mean.values,
        "hdi_low": comparison_hdi.sel(hdi="lower")["mu"].values,
        "hdi_high": comparison_hdi.sel(hdi="higher")["mu"].values,
    }
)
comparison_df.head(10)
mean hdi_low hdi_high
0 0.213995 0.143976 0.286619
1 0.053701 0.031074 0.079068
2 0.013598 0.006305 0.022216
3 0.003475 0.001268 0.006288
4 0.511683 0.373739 0.667900
5 0.128343 0.078841 0.183933
6 0.032483 0.015540 0.051750
7 0.008296 0.003144 0.014865
8 1.225939 0.921290 1.553605
9 0.307349 0.194784 0.436108

We can compare this comparison with bmb.interpret.comparisons.

summary_df = bmb.interpret.comparisons(
    fish_model,
    fish_idata,
    contrast={"livebait": [0, 1]},
    conditional=conditional
)
summary_df.head(10)
term estimate_type value camper persons child estimate lower_3.0% upper_97.0%
0 livebait diff (0, 1) 0 1 1 0.213995 0.143976 0.286619
1 livebait diff (0, 1) 0 1 2 0.053701 0.031074 0.079068
2 livebait diff (0, 1) 0 1 3 0.013598 0.006305 0.022216
3 livebait diff (0, 1) 0 1 4 0.003475 0.001268 0.006288
4 livebait diff (0, 1) 0 2 1 0.511683 0.373739 0.667900
5 livebait diff (0, 1) 0 2 2 0.128343 0.078841 0.183933
6 livebait diff (0, 1) 0 2 3 0.032483 0.015540 0.051750
7 livebait diff (0, 1) 0 2 4 0.008296 0.003144 0.014865
8 livebait diff (0, 1) 0 3 1 1.225939 0.921290 1.553605
9 livebait diff (0, 1) 0 3 2 0.307349 0.194784 0.436108

Aside from the additional information in summary_df, the columns estimate, lower_3.0%, and upper_97.0% are identical.

Cross comparisons

Computing a cross-comparison is useful when we want to compare contrasts as two (or more) predictors change at the same time. Cross-comparisons are currently not supported in the comparisons function, but we can use select_draws to compute them ourselves. For example, imagine we are interested in computing the cross-comparison between the two rows below.

summary_df.iloc[:2]
term estimate_type value camper persons child estimate lower_3.0% upper_97.0%
0 livebait diff (0, 1) 0 1 1 0.213995 0.143976 0.286619
1 livebait diff (0, 1) 0 1 2 0.053701 0.031074 0.079068

The cross-comparison amounts to first computing the comparison for each row. The comparison for row 0, computed below, can be verified against the estimate in summary_df.

cond_10 = {
    "camper": 0,
    "persons": 1,
    "child": 1,
    "livebait": 0 
}

cond_11 = {
    "camper": 0,
    "persons": 1,
    "child": 1,
    "livebait": 1
}

draws_10 = select_draws(idata_grid, grid, cond_10, "mu")
draws_11 = select_draws(idata_grid, grid, cond_11, "mu")

(draws_11 - draws_10).mean(("chain", "draw")).item()
0.21399545475519965

Next, we need to compute the comparison for row 1.

cond_20 = {
    "camper": 0,
    "persons": 1,
    "child": 2,
    "livebait": 0
}

cond_21 = {
    "camper": 0,
    "persons": 1,
    "child": 2,
    "livebait": 1
}

draws_20 = select_draws(idata_grid, grid, cond_20, "mu")
draws_21 = select_draws(idata_grid, grid, cond_21, "mu")
(draws_21 - draws_20).mean(("chain", "draw")).item()
0.053701205869420204

Having computed the “first level” comparisons (diff_1 and diff_2), we take the difference between these two differences to obtain the cross-comparison.

diff_1 = (draws_11 - draws_10)
diff_2 = (draws_21 - draws_20)

cross_comparison = (diff_2 - diff_1).mean(("chain", "draw")).item()
cross_comparison
-0.16029424888577945

To verify this is correct, we can perform the cross-comparison directly on summary_df.

summary_df.iloc[1]["estimate"] - summary_df.iloc[0]["estimate"]
np.float64(-0.16029424888577945)
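This check works because the mean is linear: averaging the difference of draws equals differencing the averages, so subtracting the two summary estimates reproduces the draw-level cross-comparison exactly. A quick numerical sketch with stand-in draws:

```python
import numpy as np

# Stand-ins for (draws_11 - draws_10) and (draws_21 - draws_20).
rng = np.random.default_rng(42)
diff_1 = rng.normal(0.21, 0.05, size=4000)
diff_2 = rng.normal(0.05, 0.02, size=4000)

# Linearity of the mean: mean(d2 - d1) == mean(d2) - mean(d1).
direct = (diff_2 - diff_1).mean()
from_means = diff_2.mean() - diff_1.mean()
print(np.isclose(direct, from_means))  # True
```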

Summary

In this notebook, the interpret helper functions data_grid and select_draws were introduced, and it was demonstrated how they can be used to compute pairwise grids of data and cross-comparisons. With these functions, it is left to the user to generate their own grids of data and quantities of interest, allowing more flexibility and control over the data passed to model.predict and the quantities computed.

%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Sun Sep 28 2025

Python implementation: CPython
Python version       : 3.13.7
IPython version      : 9.4.0

numpy : 2.3.3
arviz : 0.22.0
bambi : 0.14.1.dev57+g7b2bb342c.d20250928
pandas: 2.3.2

Watermark: 2.5.0