Interpret Advanced Usage

The interpret module provides a set of helper functions to aid the user in more advanced and complex analyses not covered by the comparisons, predictions, and slopes functions. These helper functions are data_grid and select_draws. data_grid creates a pairwise grid of data points that can be passed to model.predict. select_draws then selects, from the posterior (or posterior predictive) group of the InferenceData object returned by the predict method, the draws that correspond to the data points that “produced” them.

With access to the appropriately indexed draws, and to the data used to generate those draws, more complex analyses become possible, such as cross-comparisons and choosing which model parameter to compute a quantity of interest for. Additionally, the user has more control over the data passed to model.predict. Below, it is demonstrated how to use these helper functions: first to reproduce the results from the standard interpret API, and then to compute cross-comparisons.

import warnings

import arviz as az
import numpy as np
import pandas as pd

import bambi as bmb

from bambi.interpret.helpers import data_grid, select_draws

warnings.simplefilter(action='ignore', category=FutureWarning)

Zero Inflated Poisson

We will adopt the zero inflated Poisson (ZIP) model from the comparisons documentation to demonstrate the helper functions introduced above.

The ZIP model will be used to predict how many fish are caught by visitors at a state park using survey data. Many visitors catch zero fish, either because they did not fish at all, or because they were unlucky. We would like to explicitly model this bimodal behavior (zero versus non-zero) using a Zero Inflated Poisson model, and to compare how different inputs of interest \(w\) and other covariate values \(c\) are associated with the number of fish caught. The dataset contains data on 250 groups that went to a state park to fish. Each group was questioned about how many fish they caught (count), how many children were in the group (child), how many people were in the group (persons), whether they used live bait (livebait), and whether or not they brought a camper to the park (camper).

fish_data = pd.read_csv("https://stats.idre.ucla.edu/stat/data/fish.csv")
cols = ["count", "livebait", "camper", "persons", "child"]
fish_data = fish_data[cols]
fish_data["child"] = fish_data["child"].astype(np.int8)
fish_data["persons"] = fish_data["persons"].astype(np.int8)
fish_data["livebait"] = pd.Categorical(fish_data["livebait"])
fish_data["camper"] = pd.Categorical(fish_data["camper"])
fish_model = bmb.Model(
    "count ~ livebait + camper + persons + child", 
    fish_data, 
    family='zero_inflated_poisson'
)

fish_idata = fish_model.fit(random_seed=1234)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [psi, Intercept, livebait, camper, persons, child]


Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 8 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics

Create a grid of data

data_grid allows you to create a pairwise grid, also known as a cross-join or Cartesian product, of data using the covariates passed to conditional and to the optional variable parameter. Covariates that are terms in the Bambi model but are not passed to conditional are set to typical values (e.g., mean or mode). If you are coming from R, this function is partially inspired by the data_grid function in {modelr}.

There are two ways to create a pairwise grid of data:

  1. Pass user-provided values as a dictionary to conditional, where the keys are covariate names and the values are the values to use in the grid.
  2. Pass a list of covariate names to conditional. Since only names are given, default values are computed to construct the grid.

Any unspecified covariates, i.e., covariates that are terms in the Bambi model but are not passed to conditional, are set to their “typical” value (mean or mode, depending on the covariate's data type).
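
To make the pairwise grid concrete, here is a minimal sketch of the same kind of cross-join built by hand with itertools.product. The covariate values are only illustrative; data_grid performs this product for you and additionally fills in typical values for any unspecified model terms.

import itertools

# Illustrative covariate values; data_grid builds this Cartesian product for you
values = {
    "camper": [0, 1],
    "persons": [1, 2, 3, 4],
    "child": [1, 2, 3],
}

# Cross-join: one row for every combination of the values above
grid_by_hand = pd.DataFrame(
    list(itertools.product(*values.values())), columns=list(values.keys())
)
grid_by_hand.shape  # (2 * 4 * 3, 3) = (24, 3)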

User-provided values

To construct a pairwise grid of data for specific covariate values, pass a dictionary to conditional. The values of the dictionary can be of type int, float, list, or np.ndarray.

conditional = {
    "camper": np.array([0, 1]),
    "persons": np.arange(1, 5, 1),
    "child": np.array([1, 2, 3]),
}
user_passed_grid = data_grid(fish_model, conditional)
user_passed_grid.query("camper == 0")
Default computed for unspecified variable: livebait
camper persons child livebait
0 0 1 1 1
1 0 1 2 1
2 0 1 3 1
3 0 2 1 1
4 0 2 2 1
5 0 2 3 1
6 0 3 1 1
7 0 3 2 1
8 0 3 3 1
9 0 4 1 1
10 0 4 2 1
11 0 4 3 1

Subsetting by camper == 0, we can see that the grid contains every possible combination of the values supplied in the dictionary, along with the unspecified variable livebait. livebait has been set to 1 because this is the mode of that unspecified categorical variable.
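
As a quick check (assuming fish_data is loaded as above), the mode of the unspecified categorical variable can be computed directly:

# Mode of the covariate that was not passed to `conditional`; data_grid
# falls back to this "typical" value, which is 1 for this dataset.
fish_data["livebait"].mode()[0]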

Default values

Alternatively, a list of covariate names can be passed to conditional. By doing this, you are telling interpret to compute default values for these covariates. The pseudocode below outlines the logic and functions used to compute these default values:

# Numeric (float) covariates: 50 evenly spaced values over the observed range
if is_numeric_dtype(x) or is_float_dtype(x):
    values = np.linspace(np.min(x), np.max(x), 50)

# Integer covariates: quantiles of the observed values
elif is_integer_dtype(x):
    values = np.quantile(x, np.linspace(0, 1, 5))

# Categorical, string, or object covariates: the unique observed values
elif is_categorical_dtype(x) or is_string_dtype(x) or is_object_dtype(x):
    values = np.unique(x)

conditional = ["camper", "persons", "child"]
default_grid = data_grid(fish_model, conditional)

default_grid.shape, user_passed_grid.shape
Default computed for conditional variable: camper, persons, child
Default computed for unspecified variable: livebait
((32, 4), (24, 4))

Notice that the user-passed and default grids have different lengths. This is because the default values computed for child range from 0 to 3 (four values), whereas the user-passed grid used only 1 to 3.
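
The default values chosen for each covariate can be inspected directly on the returned grid:

# Unique values used for each covariate in the default grid
{col: default_grid[col].unique().tolist() for col in default_grid.columns}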

default_grid.query("camper == 0")
camper persons child livebait
0 0 1 0 1
1 0 1 1 1
2 0 1 2 1
3 0 1 3 1
4 0 2 0 1
5 0 2 1 1
6 0 2 2 1
7 0 2 3 1
8 0 3 0 1
9 0 3 1 1
10 0 3 2 1
11 0 3 3 1
12 0 4 0 1
13 0 4 1 1
14 0 4 2 1
15 0 4 3 1

Compute comparisons

To use data_grid to help generate data for computing comparisons or slopes, additional data is passed to the optional variable parameter. The name variable is an abstraction for the comparisons parameter contrast and the slopes parameter wrt. If you have used any of the interpret functions, these parameter names should be familiar, and the use of data_grid is analogous to comparisons, predictions, and slopes.

variable can be passed either user-provided data (as a dictionary) or a string naming the covariate of interest. In the latter case, a default value is computed. Additionally, if an argument is passed for variable, then effect_type must also be passed. This is because for comparisons and slopes an epsilon value eps needs to be determined in order to compute the centered difference and the finite difference, respectively. You can also pass a value for eps as a kwarg.
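
For instance, here is a hedged sketch of passing only the name of a numeric covariate for variable and supplying eps explicitly; the particular values are illustrative, and eps is assumed to be accepted as a keyword argument as noted above.

# A sketch: a default value is computed for the numeric variable "persons",
# and eps (assumed to be accepted as a kwarg) controls the centered difference.
persons_grid = data_grid(
    fish_model,
    ["camper", "child"],
    "persons",
    effect_type="comparisons",
    eps=1e-4,
)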

conditional = {
    "camper": np.array([0, 1]),
    "persons": np.arange(1, 5, 1),
    "child": np.array([1, 2, 3, 4])
}
variable = "livebait"

grid = data_grid(fish_model, conditional, variable, effect_type="comparisons")
Default computed for contrast variable: livebait
idata_grid = fish_model.predict(fish_idata, data=grid, inplace=False)

Select draws conditional on data

The second helper function to aid in more advanced analysis is select_draws. It selects the posterior or posterior predictive draws from the ArviZ InferenceData object returned by model.predict, given a conditional dictionary specifying the covariate values the selected draws should correspond to.

For example, if we wanted to select the posterior draws where livebait = 0, all we need to do is pass a dictionary where the key is the name of the covariate and the value is the value we want to condition on (or select). The selected draws are those corresponding to the data points where livebait = 0. Additionally, you must pass the InferenceData object returned by model.predict, the data used to generate the predictions, and the name of the data variable data_var you would like to select from the InferenceData posterior group. If you generated posterior predictive samples by calling model.predict(..., kind="pps"), you can select from that group instead of the posterior group by passing group="posterior_predictive".
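
For example, here is a hedged sketch of selecting from the posterior predictive group instead of the posterior; it assumes the predictions are generated with kind="pps" and that the response variable in that group is named count.

# A sketch: generate posterior predictive samples on the grid and select the
# draws where livebait = 1. The variable name "count" is assumed to match the
# response term in the posterior predictive group.
idata_pps = fish_model.predict(fish_idata, data=grid, kind="pps", inplace=False)
pps_draws = select_draws(
    idata_pps, grid, {"livebait": 1}, "count", group="posterior_predictive"
)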

Below, it is demonstrated how to compute comparisons on the model parameter mu for the contrast livebait = [0, 1] using the posterior draws.

idata_grid
arviz.InferenceData
    • <xarray.Dataset> Size: 1MB
      Dimensions:       (chain: 2, draw: 1000, camper_dim: 1, livebait_dim: 1,
                         __obs__: 64)
      Coordinates:
        * chain         (chain) int64 16B 0 1
        * draw          (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
        * camper_dim    (camper_dim) <U1 4B '1'
        * livebait_dim  (livebait_dim) <U1 4B '1'
        * __obs__       (__obs__) int64 512B 0 1 2 3 4 5 6 7 ... 57 58 59 60 61 62 63
      Data variables:
          Intercept     (chain, draw) float64 16kB -2.319 -2.441 ... -2.705 -2.571
          camper        (chain, draw, camper_dim) float64 16kB 0.7341 0.6539 ... 0.763
          child         (chain, draw) float64 16kB -1.232 -1.251 ... -1.437 -1.344
          livebait      (chain, draw, livebait_dim) float64 16kB 1.509 1.634 ... 1.7
          persons       (chain, draw) float64 16kB 0.8448 0.8763 ... 0.903 0.8508
          psi           (chain, draw) float64 16kB 0.6017 0.6008 ... 0.6933 0.5425
          mu            (chain, draw, __obs__) float64 1MB 0.06677 0.3021 ... 0.1248
      Attributes:
          created_at:                  2024-05-26T21:58:43.655562+00:00
          arviz_version:               0.18.0
          inference_library:           pymc
          inference_library_version:   5.15.0+23.g19be124e
          sampling_time:               8.18968677520752
          tuning_steps:                1000
          modeling_interface:          bambi
          modeling_interface_version:  0.13.1.dev39+gb7d6a6cb

    • <xarray.Dataset> Size: 252kB
      Dimensions:                (chain: 2, draw: 1000)
      Coordinates:
        * chain                  (chain) int64 16B 0 1
        * draw                   (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999
      Data variables: (12/17)
          acceptance_rate        (chain, draw) float64 16kB 0.4852 0.4691 ... 0.4878
          diverging              (chain, draw) bool 2kB False False ... False False
          energy                 (chain, draw) float64 16kB 756.7 754.8 ... 755.2
          energy_error           (chain, draw) float64 16kB 0.1849 -0.1576 ... 0.6626
          index_in_trajectory    (chain, draw) int64 16kB 6 -2 -2 5 3 ... 5 -2 -5 -5 6
          largest_eigval         (chain, draw) float64 16kB nan nan nan ... nan nan
          ...                     ...
          process_time_diff      (chain, draw) float64 16kB 0.004515 ... 0.002555
          reached_max_treedepth  (chain, draw) bool 2kB False False ... False False
          smallest_eigval        (chain, draw) float64 16kB nan nan nan ... nan nan
          step_size              (chain, draw) float64 16kB 0.4493 0.4493 ... 0.2491
          step_size_bar          (chain, draw) float64 16kB 0.4092 0.4092 ... 0.3806
          tree_depth             (chain, draw) int64 16kB 4 3 3 4 3 3 ... 4 3 3 3 4 3
      Attributes:
          created_at:                  2024-05-26T21:58:43.673049+00:00
          arviz_version:               0.18.0
          inference_library:           pymc
          inference_library_version:   5.15.0+23.g19be124e
          sampling_time:               8.18968677520752
          tuning_steps:                1000
          modeling_interface:          bambi
          modeling_interface_version:  0.13.1.dev39+gb7d6a6cb

    • <xarray.Dataset> Size: 4kB
      Dimensions:  (__obs__: 250)
      Coordinates:
        * __obs__  (__obs__) int64 2kB 0 1 2 3 4 5 6 7 ... 243 244 245 246 247 248 249
      Data variables:
          count    (__obs__) int64 2kB 0 0 0 0 1 0 0 0 0 1 0 ... 4 1 1 0 1 0 0 0 0 0 0
      Attributes:
          created_at:                  2024-05-26T21:58:43.678881+00:00
          arviz_version:               0.18.0
          inference_library:           pymc
          inference_library_version:   5.15.0+23.g19be124e
          modeling_interface:          bambi
          modeling_interface_version:  0.13.1.dev39+gb7d6a6cb

draw_1 = select_draws(idata_grid, grid, {"livebait": 0}, "mu")
draw_2 = select_draws(idata_grid, grid, {"livebait": 1}, "mu")

comparison_mean = (draw_2 - draw_1).mean(("chain", "draw"))
comparison_hdi = az.hdi(draw_2 - draw_1)

comparison_df = pd.DataFrame(
    {
        "mean": comparison_mean.values,
        "hdi_low": comparison_hdi.sel(hdi="lower")["mu"].values,
        "hdi_high": comparison_hdi.sel(hdi="higher")["mu"].values,
    }
)
comparison_df.head(10)
mean hdi_low hdi_high
0 0.216038 0.150384 0.294501
1 0.054384 0.029894 0.077987
2 0.013811 0.005794 0.021496
3 0.003539 0.001142 0.006083
4 0.515839 0.375190 0.670931
5 0.129765 0.078764 0.182346
6 0.032932 0.014961 0.050478
7 0.008432 0.002665 0.014138
8 1.234183 0.926056 1.553555
9 0.310259 0.195251 0.425951

We can compare these results with the output of bmb.interpret.comparisons.

summary_df = bmb.interpret.comparisons(
    fish_model,
    fish_idata,
    contrast={"livebait": [0, 1]},
    conditional=conditional
)
summary_df.head(10)
term estimate_type value camper persons child estimate lower_3.0% upper_97.0%
0 livebait diff (0, 1) 0 1 1 0.216038 0.150384 0.294501
1 livebait diff (0, 1) 0 1 2 0.054384 0.029894 0.077987
2 livebait diff (0, 1) 0 1 3 0.013811 0.005794 0.021496
3 livebait diff (0, 1) 0 1 4 0.003539 0.001142 0.006083
4 livebait diff (0, 1) 0 2 1 0.515839 0.375190 0.670931
5 livebait diff (0, 1) 0 2 2 0.129765 0.078764 0.182346
6 livebait diff (0, 1) 0 2 3 0.032932 0.014961 0.050478
7 livebait diff (0, 1) 0 2 4 0.008432 0.002665 0.014138
8 livebait diff (0, 1) 0 3 1 1.234183 0.926056 1.553555
9 livebait diff (0, 1) 0 3 2 0.310259 0.195251 0.425951

Aside from the additional columns in summary_df, the estimate, lower_3.0%, and upper_97.0% values are identical.

Cross comparisons

Computing a cross-comparison is useful when we want to compare contrasts as two (or more) predictors change at the same time. Cross-comparisons are currently not supported in the comparisons function, but we can use select_draws to compute them. For example, imagine we are interested in computing the cross-comparison between the two rows below.

summary_df.iloc[:2]
term estimate_type value camper persons child estimate lower_3.0% upper_97.0%
0 livebait diff (0, 1) 0 1 1 0.216038 0.150384 0.294501
1 livebait diff (0, 1) 0 1 2 0.054384 0.029894 0.077987
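
Concretely, with camper = 0 and persons = 1 held fixed, and writing \(\mu\) for the model parameter mu selected from the draws, the cross-comparison computed below is a difference of differences:

\[
\left(\mu_{\text{child}=2,\ \text{livebait}=1} - \mu_{\text{child}=2,\ \text{livebait}=0}\right) -
\left(\mu_{\text{child}=1,\ \text{livebait}=1} - \mu_{\text{child}=1,\ \text{livebait}=0}\right)
\]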

The cross-comparison amounts to first computing the comparison for row 0, shown below; the result can be verified against the corresponding estimate in summary_df.

cond_10 = {
    "camper": 0,
    "persons": 1,
    "child": 1,
    "livebait": 0 
}

cond_11 = {
    "camper": 0,
    "persons": 1,
    "child": 1,
    "livebait": 1
}

draws_10 = select_draws(idata_grid, grid, cond_10, "mu")
draws_11 = select_draws(idata_grid, grid, cond_11, "mu")

(draws_11 - draws_10).mean(("chain", "draw")).item()
0.21603768204145285

Next, we need to compute the comparison for row 1.

cond_20 = {
    "camper": 0,
    "persons": 1,
    "child": 2,
    "livebait": 0
}

cond_21 = {
    "camper": 0,
    "persons": 1,
    "child": 2,
    "livebait": 1
}

draws_20 = select_draws(idata_grid, grid, cond_20, "mu")
draws_21 = select_draws(idata_grid, grid, cond_21, "mu")
(draws_21 - draws_20).mean(("chain", "draw")).item()
0.0543836038546379

We now have the two “first level” comparisons, which we store as diff_1 and diff_2. The difference between these two differences is the cross-comparison.

diff_1 = (draws_11 - draws_10)
diff_2 = (draws_21 - draws_20)

cross_comparison = (diff_2 - diff_1).mean(("chain", "draw")).item()
cross_comparison
-0.16165407818681496

To verify this is correct, we can perform the cross-comparison directly on summary_df.

summary_df.iloc[1]["estimate"] - summary_df.iloc[0]["estimate"]
-0.16165407818681496

Summary

In this notebook, the interpret helper functions data_grid and select_draws were introduced, and it was demonstrated how they can be used to create pairwise grids of data and to compute cross-comparisons. With these functions, it is left to the user to generate the grid of data and the quantities of interest, which allows more flexibility and control over the data passed to model.predict and over the quantities computed from the resulting draws.

%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Sun May 26 2024

Python implementation: CPython
Python version       : 3.11.9
IPython version      : 8.24.0

numpy : 1.26.4
arviz : 0.18.0
bambi : 0.13.1.dev39+gb7d6a6cb
pandas: 2.2.2

Watermark: 2.4.3