Interpret Advanced Usage

The interpret module provides a set of helper functions to aid the user in more advanced and complex analyses not covered by the comparisons, predictions, and slopes functions. These helper functions are data_grid and select_draws. data_grid creates a pairwise grid of data points for the user to pass to model.predict. Subsequently, select_draws selects, from the posterior (or posterior predictive) group of the InferenceData object returned by the predict method, the draws that correspond to the data points that “produced” them.

With access to the appropriately indexed draws, and the data used to generate them, more complex analyses become possible, such as cross-comparisons and choosing which model parameter to compute a quantity of interest for, among others. Additionally, the user has more control over the data passed to model.predict. Below, it is demonstrated how to use these helper functions: first to reproduce the results of the standard interpret API, and then to compute cross-comparisons.

import warnings

import arviz as az
import numpy as np
import pandas as pd

import bambi as bmb

from bambi.interpret.helpers import data_grid, select_draws

warnings.simplefilter(action='ignore', category=FutureWarning)

Zero Inflated Poisson

We will adopt the zero inflated Poisson (ZIP) model from the comparisons documentation to demonstrate the helper functions introduced above.

The ZIP model will be used to predict how many fish are caught by visitors at a state park using survey data. Many visitors catch zero fish, either because they did not fish at all, or because they were unlucky. We would like to explicitly model this bimodal behavior (zero versus non-zero) using a Zero Inflated Poisson model, and to compare how different inputs of interest \(w\) and other covariate values \(c\) are associated with the number of fish caught. The dataset contains data on 250 groups that went to a state park to fish. Each group was questioned about how many fish they caught (count), how many children were in the group (child), how many people were in the group (persons), whether they used live bait (livebait), and whether or not they brought a camper to the park (camper).

fish_data = pd.read_stata("http://www.stata-press.com/data/r11/fish.dta")
cols = ["count", "livebait", "camper", "persons", "child"]
fish_data = fish_data[cols]
fish_data["child"] = fish_data["child"].astype(np.int8)
fish_data["persons"] = fish_data["persons"].astype(np.int8)
fish_data["livebait"] = pd.Categorical(fish_data["livebait"])
fish_data["camper"] = pd.Categorical(fish_data["camper"])
fish_model = bmb.Model(
    "count ~ livebait + camper + persons + child", 
    fish_data, 
    family='zero_inflated_poisson'
)

fish_idata = fish_model.fit(random_seed=1234)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [count_psi, Intercept, livebait, camper, persons, child]
100.00% [8000/8000 00:02<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.

Create a grid of data

data_grid allows you to create a pairwise grid, also known as a cross-join or cartesian product, of data using the covariates passed to conditional and the optional variable parameter. If you are coming from R, this function is partially inspired by the data_grid function in {modelr}.

There are two ways to create a pairwise grid of data:

  1. user-provided values are passed as a dictionary to conditional, where the keys are the names of the covariates and the values are the values to use in the grid.
  2. a list of covariate names is passed to conditional; since only the names are given, default values are computed to construct the grid.

Any unspecified covariates, i.e., covariates that are terms in the Bambi model but were not passed to conditional, are set to their “typical” values, such as the mean or mode, depending on the data type of the covariate.
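The pairwise grid is simply the cartesian product of the supplied value sets. As a minimal illustration of that idea (using itertools rather than bambi, with hypothetical covariate values mirroring the fish example below), the grid has 2 × 4 × 3 = 24 rows:

```python
import itertools

import pandas as pd

# Hypothetical covariate values, mirroring the fish example below
conditional = {
    "camper": [0, 1],
    "persons": [1, 2, 3, 4],
    "child": [1, 2, 3],
}

# A pairwise grid is the cartesian product of all value sets
grid = pd.DataFrame(
    list(itertools.product(*conditional.values())),
    columns=list(conditional),
)
print(grid.shape)  # (24, 3)
```

data_grid builds the same kind of grid, additionally filling in typical values for any model terms not listed.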

User-provided values

To construct a pairwise grid of data for specific covariate values, pass a dictionary to conditional. The values of the dictionary can be of type int, float, list, or np.ndarray.

conditional = {
    "camper": np.array([0, 1]),
    "persons": np.arange(1, 5, 1),
    "child": np.array([1, 2, 3]),
}
user_passed_grid = data_grid(fish_model, conditional)
user_passed_grid.query("camper == 0")
Default computed for unspecified variable: livebait
camper persons child livebait
0 0 1 1 1.0
1 0 1 2 1.0
2 0 1 3 1.0
3 0 2 1 1.0
4 0 2 2 1.0
5 0 2 3 1.0
6 0 3 1 1.0
7 0 3 2 1.0
8 0 3 3 1.0
9 0 4 1 1.0
10 0 4 2 1.0
11 0 4 3 1.0

Subsetting by camper == 0, we see every possible combination of the values passed in the dictionary (plus the unspecified variable livebait). livebait has been set to 1.0, the mode of this unspecified categorical variable.

Default values

Alternatively, a list of covariate names can be passed to conditional. By doing this, you are telling interpret to compute default values for these covariates. The pseudocode below outlines the logic and functions used to compute these default values:

if is_float_dtype(x):
    values = np.linspace(np.min(x), np.max(x), 50)

elif is_integer_dtype(x):
    values = np.quantile(x, np.linspace(0, 1, 5))

elif is_categorical_dtype(x) or is_string_dtype(x) or is_object_dtype(x):
    values = np.unique(x)
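A runnable version of this rule, using pandas dtype checks (a sketch of the logic above, not bambi's exact implementation):

```python
import numpy as np
import pandas as pd
from pandas.api.types import is_float_dtype, is_integer_dtype


def default_values(x: pd.Series) -> np.ndarray:
    """Sketch of interpret's default-value rule for a single covariate."""
    if is_float_dtype(x):
        # continuous: 50 evenly spaced values over the observed range
        return np.linspace(np.min(x), np.max(x), 50)
    elif is_integer_dtype(x):
        # integer: five evenly spaced quantiles
        return np.quantile(x, np.linspace(0, 1, 5))
    # categorical / string / object: every observed level
    return np.unique(x)


# child is an integer covariate, so its defaults are quantiles
print(default_values(pd.Series([0, 1, 2, 3], dtype="int8")))
```
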
conditional = ["camper", "persons", "child"]
default_grid = data_grid(fish_model, conditional)

default_grid.shape, user_passed_grid.shape
Default computed for conditional variable: camper, persons, child
Default computed for unspecified variable: livebait
((32, 4), (24, 4))

Notice that the two grids differ in length. The user-passed grid used child = {1, 2, 3}, giving 2 × 4 × 3 = 24 rows, whereas the default grid computed child = {0, 1, 2, 3} from quantiles of the data, giving 2 × 4 × 4 = 32 rows.

default_grid.query("camper == 0")
camper persons child livebait
0 0.0 1 0 1.0
1 0.0 1 1 1.0
2 0.0 1 2 1.0
3 0.0 1 3 1.0
4 0.0 2 0 1.0
5 0.0 2 1 1.0
6 0.0 2 2 1.0
7 0.0 2 3 1.0
8 0.0 3 0 1.0
9 0.0 3 1 1.0
10 0.0 3 2 1.0
11 0.0 3 3 1.0
12 0.0 4 0 1.0
13 0.0 4 1 1.0
14 0.0 4 2 1.0
15 0.0 4 3 1.0

Compute comparisons

To use data_grid to generate data for computing comparisons or slopes, additional data is passed to the optional variable parameter. The name variable is an abstraction for the comparisons parameter contrast and the slopes parameter wrt. If you have used any of the interpret functions, these parameter names should be familiar, and using data_grid should feel analogous to comparisons, predictions, and slopes.

variable accepts either user-provided data (as a dictionary) or a string naming the covariate of interest; in the latter case, a default value is computed. Additionally, if an argument is passed for variable, then effect_type must also be passed. This is because comparisons and slopes need an epsilon value eps to compute the centered and finite difference, respectively. You can also pass a value for eps as a kwarg.
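To see why eps matters, recall how a centered finite difference approximates a slope. This is a generic illustration, not bambi's internals: the function is evaluated just above and just below a point, and the difference is scaled by eps.

```python
def centered_diff(f, x, eps=1e-4):
    # slope of f at x, approximated by a centered finite difference
    return (f(x + eps / 2) - f(x - eps / 2)) / eps


# the slope of x**2 at x = 3 is 6
print(centered_diff(lambda x: x**2, 3.0))
```

A smaller eps gives a better local approximation, at the cost of floating-point precision.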

conditional = {
    "camper": np.array([0, 1]),
    "persons": np.arange(1, 5, 1),
    "child": np.array([1, 2, 3, 4])
}
variable = "livebait"

grid = data_grid(fish_model, conditional, variable, effect_type="comparisons")
Default computed for contrast variable: livebait
idata_grid = fish_model.predict(fish_idata, data=grid, inplace=False)

Select draws conditional on data

The second helper function to aid in more advanced analysis is select_draws. It selects the posterior (or posterior predictive) draws from the ArviZ InferenceData object returned by model.predict, given a conditional dictionary that maps covariate names to the values whose corresponding draws should be selected.

For example, to select the posterior draws where livebait = [0, 1], pass a dictionary whose key is the name of the covariate and whose value is the value to condition on (or select). The resulting InferenceData object will contain the draws that correspond to the data points where livebait = [0, 1]. Additionally, you must pass the InferenceData object returned by model.predict, the data used to generate the predictions, and the name data_var of the data variable you would like to select from the InferenceData posterior group. If you requested posterior predictive samples with model.predict(..., kind="pps"), you can select from that group instead by passing group="posterior_predictive".
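Conceptually, select_draws finds the rows of the grid that match the condition, then indexes the observation dimension of the draws with those row positions. A toy sketch of that indexing with made-up data (not bambi's implementation):

```python
import numpy as np
import pandas as pd
import xarray as xr

# Toy stand-ins: a grid of four data points and draws of shape
# (chain, draw, count_obs)
grid = pd.DataFrame({"livebait": [0, 0, 1, 1], "child": [1, 2, 1, 2]})
draws = xr.DataArray(
    np.arange(2 * 3 * 4).reshape(2, 3, 4),
    dims=("chain", "draw", "count_obs"),
)

# Rows of the grid where the condition holds...
idx = grid.query("livebait == 0").index.to_numpy()
# ...pick out the matching slice of the observation dimension
selected = draws.isel(count_obs=idx)
print(dict(selected.sizes))  # {'chain': 2, 'draw': 3, 'count_obs': 2}
```
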

Below, it is demonstrated how to compute comparisons for count_mean for the contrast livebait = [0, 1] using the posterior draws.

idata_grid
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:       (chain: 4, draw: 1000, livebait_dim: 1, camper_dim: 1,
                         count_obs: 64)
      Coordinates:
        * chain         (chain) int64 0 1 2 3
        * draw          (draw) int64 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
        * livebait_dim  (livebait_dim) <U3 '1.0'
        * camper_dim    (camper_dim) <U3 '1.0'
        * count_obs     (count_obs) int64 0 1 2 3 4 5 6 7 ... 56 57 58 59 60 61 62 63
      Data variables:
          Intercept     (chain, draw) float64 -2.454 -2.31 -2.91 ... -2.652 -2.887
          livebait      (chain, draw, livebait_dim) float64 1.629 1.58 ... 1.799 1.967
          camper        (chain, draw, camper_dim) float64 0.7037 0.7089 ... 0.7128
          persons       (chain, draw) float64 0.8707 0.8369 0.9457 ... 0.8847 0.912
          child         (chain, draw) float64 -1.345 -1.412 -1.418 ... -1.293 -1.573
          count_psi     (chain, draw) float64 0.6311 0.6201 0.6342 ... 0.6768 0.5745
          count_mean    (chain, draw, count_obs) float64 0.05349 0.2728 ... 0.05777
      Attributes:
          created_at:                  2023-12-05T18:56:31.591639
          arviz_version:               0.16.1
          inference_library:           pymc
          inference_library_version:   5.8.1
          sampling_time:               2.6078336238861084
          tuning_steps:                1000
          modeling_interface:          bambi
          modeling_interface_version:  0.13.0.dev0

    • <xarray.Dataset>
      Dimensions:                (chain: 4, draw: 1000)
      Coordinates:
        * chain                  (chain) int64 0 1 2 3
        * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999
      Data variables: (12/17)
          max_energy_error       (chain, draw) float64 0.595 1.645 ... -0.2578 0.4279
          reached_max_treedepth  (chain, draw) bool False False False ... False False
          process_time_diff      (chain, draw) float64 0.000885 0.000874 ... 0.001667
          diverging              (chain, draw) bool False False False ... False False
          n_steps                (chain, draw) float64 7.0 7.0 7.0 ... 7.0 11.0 15.0
          lp                     (chain, draw) float64 -750.1 -750.2 ... -751.9 -754.0
          ...                     ...
          largest_eigval         (chain, draw) float64 nan nan nan nan ... nan nan nan
          step_size              (chain, draw) float64 0.415 0.415 ... 0.3918 0.3918
          smallest_eigval        (chain, draw) float64 nan nan nan nan ... nan nan nan
          perf_counter_start     (chain, draw) float64 5.346e+04 ... 5.346e+04
          acceptance_rate        (chain, draw) float64 0.7512 0.5492 ... 0.9919 0.8264
          index_in_trajectory    (chain, draw) int64 6 4 -6 -5 -2 4 ... 7 1 4 4 -1 9
      Attributes:
          created_at:                  2023-12-05T18:56:31.597938
          arviz_version:               0.16.1
          inference_library:           pymc
          inference_library_version:   5.8.1
          sampling_time:               2.6078336238861084
          tuning_steps:                1000
          modeling_interface:          bambi
          modeling_interface_version:  0.13.0.dev0

    • <xarray.Dataset>
      Dimensions:    (count_obs: 250)
      Coordinates:
        * count_obs  (count_obs) int64 0 1 2 3 4 5 6 7 ... 243 244 245 246 247 248 249
      Data variables:
          count      (count_obs) int64 0 0 0 0 1 0 0 0 0 1 0 ... 4 1 1 0 1 0 0 0 0 0 0
      Attributes:
          created_at:                  2023-12-05T18:56:31.600135
          arviz_version:               0.16.1
          inference_library:           pymc
          inference_library_version:   5.8.1
          modeling_interface:          bambi
          modeling_interface_version:  0.13.0.dev0

draw_1 = select_draws(idata_grid, grid, {"livebait": 0}, "count_mean")
draw_2 = select_draws(idata_grid, grid, {"livebait": 1}, "count_mean")

comparison_mean = (draw_2 - draw_1).mean(("chain", "draw"))
comparison_hdi = az.hdi(draw_2 - draw_1)

comparison_df = pd.DataFrame(
    {
        "mean": comparison_mean.values,
        "hdi_low": comparison_hdi.sel(hdi="lower")["count_mean"].values,
        "hdi_high": comparison_hdi.sel(hdi="higher")["count_mean"].values,
    }
)
comparison_df.head(10)
mean hdi_low hdi_high
0 0.214363 0.144309 0.287735
1 0.053678 0.029533 0.077615
2 0.013558 0.006332 0.021971
3 0.003454 0.001132 0.006040
4 0.512709 0.369741 0.661034
5 0.128316 0.077068 0.181741
6 0.032392 0.015553 0.050690
7 0.008247 0.003047 0.014382
8 1.228708 0.913514 1.533121
9 0.307342 0.192380 0.426808

We can verify this result against bmb.interpret.comparisons.

summary_df = bmb.interpret.comparisons(
    fish_model,
    fish_idata,
    contrast={"livebait": [0, 1]},
    conditional=conditional
)
summary_df.head(10)
term estimate_type value camper persons child estimate lower_3.0% upper_97.0%
0 livebait diff (0, 1) 0 1 1 0.214363 0.144309 0.287735
1 livebait diff (0, 1) 0 1 2 0.053678 0.029533 0.077615
2 livebait diff (0, 1) 0 1 3 0.013558 0.006332 0.021971
3 livebait diff (0, 1) 0 1 4 0.003454 0.001132 0.006040
4 livebait diff (0, 1) 0 2 1 0.512709 0.369741 0.661034
5 livebait diff (0, 1) 0 2 2 0.128316 0.077068 0.181741
6 livebait diff (0, 1) 0 2 3 0.032392 0.015553 0.050690
7 livebait diff (0, 1) 0 2 4 0.008247 0.003047 0.014382
8 livebait diff (0, 1) 0 3 1 1.228708 0.913514 1.533121
9 livebait diff (0, 1) 0 3 2 0.307342 0.192380 0.426808

Aside from the additional columns in summary_df, the estimate, lower_3.0%, and upper_97.0% values are identical to those computed manually above.

Cross comparisons

Computing a cross-comparison is useful when we want to compare contrasts as two (or more) predictors change at the same time. Cross-comparisons are currently not supported in the comparisons function, but we can use select_draws to compute them. For example, imagine we are interested in computing the cross-comparison between the two rows below.

summary_df.iloc[:2]
term estimate_type value camper persons child estimate lower_3.0% upper_97.0%
0 livebait diff (0, 1) 0 1 1 0.214363 0.144309 0.287735
1 livebait diff (0, 1) 0 1 2 0.053678 0.029533 0.077615

The cross-comparison amounts to first computing the comparison for row 0, given below; the result can be verified against the estimate in summary_df.

cond_10 = {
    "camper": 0,
    "persons": 1,
    "child": 1,
    "livebait": 0 
}

cond_11 = {
    "camper": 0,
    "persons": 1,
    "child": 1,
    "livebait": 1
}

draws_10 = select_draws(idata_grid, grid, cond_10, "count_mean")
draws_11 = select_draws(idata_grid, grid, cond_11, "count_mean")

(draws_11 - draws_10).mean(("chain", "draw")).item()
0.2143627093182434

Next, we need to compute the comparison for row 1.

cond_20 = {
    "camper": 0,
    "persons": 1,
    "child": 2,
    "livebait": 0
}

cond_21 = {
    "camper": 0,
    "persons": 1,
    "child": 2,
    "livebait": 1
}

draws_20 = select_draws(idata_grid, grid, cond_20, "count_mean")
draws_21 = select_draws(idata_grid, grid, cond_21, "count_mean")
(draws_21 - draws_20).mean(("chain", "draw")).item()
0.053678256991883604

Having computed both "first level" comparisons, we store them as diff_1 and diff_2, and then take the difference between these two differences to obtain the cross-comparison.

diff_1 = (draws_11 - draws_10)
diff_2 = (draws_21 - draws_20)

cross_comparison = (diff_2 - diff_1).mean(("chain", "draw")).item()
cross_comparison
-0.16068445232635978

To verify this is correct, we can check by performing the cross-comparison directly on the summary_df.

summary_df.iloc[1]["estimate"] - summary_df.iloc[0]["estimate"]
-0.16068445232635978

Summary

In this notebook, the interpret helper functions data_grid and select_draws were introduced, and it was demonstrated how they can be used to compute pairwise grids of data and cross-comparisons. With these functions, it is left to the user to generate their grids of data and quantities of interest, allowing more flexibility and control over the data passed to model.predict and the quantities of interest computed.

%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Tue Dec 05 2023

Python implementation: CPython
Python version       : 3.11.0
IPython version      : 8.13.2

numpy : 1.24.2
pandas: 2.1.0
bambi : 0.13.0.dev0
arviz : 0.16.1

Watermark: 2.3.1