import warnings
import arviz as az
import numpy as np
import pandas as pd
import bambi as bmb
from bambi.interpret.helpers import data_grid, select_draws
warnings.simplefilter(action='ignore', category=FutureWarning)
Interpret Advanced Usage
The `interpret` module provides a set of helper functions to aid the user in more advanced and complex analyses not covered by the `comparisons`, `predictions`, and `slopes` functions. These helper functions are `data_grid` and `select_draws`. `data_grid` can be used to create a pairwise grid of data points to pass to `model.predict`. Subsequently, `select_draws` is used to select, from the posterior (or posterior predictive) group of the InferenceData object returned by the predict method, the draws that correspond to the data points that "produced" them.

With access to the appropriately indexed draws, and the data used to generate those draws, more complex analyses become possible, such as cross-comparisons and choosing which model parameter to compute a quantity of interest for. Additionally, the user has more control over the data passed to `model.predict`. Below, we demonstrate how to use these helper functions: first to reproduce results from the standard `interpret` API, and then to compute cross-comparisons.
Zero Inflated Poisson
We will adopt the zero inflated Poisson (ZIP) model from the comparisons documentation to demonstrate the helper functions introduced above.

The ZIP model will be used to predict how many fish are caught by visitors at a state park using survey data. Many visitors catch zero fish, either because they did not fish at all, or because they were unlucky. We would like to explicitly model this bimodal behavior (zero versus non-zero) using a Zero Inflated Poisson model, and to compare how different inputs of interest \(w\) and other covariate values \(c\) are associated with the number of fish caught. The dataset contains data on 250 groups that went to a state park to fish. Each group was questioned about how many fish they caught (`count`), how many children were in the group (`child`), how many people were in the group (`persons`), whether they used live bait (`livebait`), and whether or not they brought a camper to the park (`camper`).
fish_data = pd.read_csv("https://stats.idre.ucla.edu/stat/data/fish.csv")
cols = ["count", "livebait", "camper", "persons", "child"]
fish_data = fish_data[cols]
fish_data["child"] = fish_data["child"].astype(np.int8)
fish_data["persons"] = fish_data["persons"].astype(np.int8)
fish_data["livebait"] = pd.Categorical(fish_data["livebait"])
fish_data["camper"] = pd.Categorical(fish_data["camper"])
fish_model = bmb.Model(
    "count ~ livebait + camper + persons + child",
    fish_data,
    family='zero_inflated_poisson'
)

fish_idata = fish_model.fit(random_seed=1234)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [psi, Intercept, livebait, camper, persons, child]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 8 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
Create a grid of data
`data_grid` allows you to create a pairwise grid, also known as a cross-join or Cartesian product, of data using the covariates passed to the `conditional` parameter and the optional `variable` parameter. Covariates that are terms in the Bambi model but are not passed to `conditional` are set to typical values (e.g., mean or mode). If you are coming from R, this function is partially inspired by the `data_grid` function in {modelr}.
There are two ways to create a pairwise grid of data:

- user-provided values are passed as a dictionary to `conditional` where the keys are the names of the covariates and the values are the values to use in the grid.
- a list of covariate names to use in the grid. As only the names of the covariates are passed, default values are computed to construct the grid.
Any unspecified covariates, i.e., covariates not passed to `conditional` but that are terms in the Bambi model, are set to their "typical" values such as mean or mode depending on the data type of the covariate.
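Conceptually, the pairwise grid is just the Cartesian product of the covariate value sets. The sketch below reproduces that idea with plain `itertools.product` and pandas; it is an illustration of the cross-join, not Bambi's actual implementation:

```python
from itertools import product

import pandas as pd

# Hypothetical covariate values, mirroring a dictionary passed to `conditional`
conditional = {
    "camper": [0, 1],
    "persons": [1, 2, 3, 4],
    "child": [1, 2, 3],
}

# Cartesian product of all value sets: one row per combination
grid = pd.DataFrame(list(product(*conditional.values())), columns=list(conditional))

print(len(grid))  # 2 * 4 * 3 = 24 rows
```

Unspecified model terms would then be appended as constant columns holding their typical values.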
User-provided values
To construct a pairwise grid of data for specific covariate values, pass a dictionary to `conditional`. The values of the dictionary can be of type `int`, `float`, `list`, or `np.ndarray`.
conditional = {
    "camper": np.array([0, 1]),
    "persons": np.arange(1, 5, 1),
    "child": np.array([1, 2, 3]),
}

user_passed_grid = data_grid(fish_model, conditional)
user_passed_grid.query("camper == 0")
Default computed for unspecified variable: livebait
|    | camper | persons | child | livebait |
|----|--------|---------|-------|----------|
| 0  | 0 | 1 | 1 | 1 |
| 1  | 0 | 1 | 2 | 1 |
| 2  | 0 | 1 | 3 | 1 |
| 3  | 0 | 2 | 1 | 1 |
| 4  | 0 | 2 | 2 | 1 |
| 5  | 0 | 2 | 3 | 1 |
| 6  | 0 | 3 | 1 | 1 |
| 7  | 0 | 3 | 2 | 1 |
| 8  | 0 | 3 | 3 | 1 |
| 9  | 0 | 4 | 1 | 1 |
| 10 | 0 | 4 | 2 | 1 |
| 11 | 0 | 4 | 3 | 1 |
Subsetting by `camper == 0`, it can be seen that the grid contains every possible combination of values from the dictionary (including the unspecified variable `livebait`). `livebait` has been set to 1 as this is the mode of that unspecified categorical variable.
Default values
Alternatively, a list of covariate names can be passed to `conditional`. By doing this, you are telling `interpret` to compute default values for these covariates. The pseudocode below outlines the logic and functions used to compute these default values:
if is_numeric_dtype(x) or is_float_dtype(x):
    values = np.linspace(np.min(x), np.max(x), 50)
elif is_integer_dtype(x):
    values = np.quantile(x, np.linspace(0, 1, 5))
elif is_categorical_dtype(x) or is_string_dtype(x) or is_object_dtype(x):
    values = np.unique(x)
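The pseudocode above can be made concrete with the dtype checks in `pandas.api.types`. This is a runnable sketch of the logic, not the exact `interpret` source (the float check is done first so integers fall through to the quantile branch):

```python
import numpy as np
from pandas.api.types import is_float_dtype, is_integer_dtype


def default_values(x: np.ndarray) -> np.ndarray:
    """Sketch of the default-value logic: floats get an evenly spaced grid,
    integers get quantiles, and everything else gets its unique values."""
    if is_float_dtype(x):
        return np.linspace(np.min(x), np.max(x), 50)
    if is_integer_dtype(x):
        return np.quantile(x, np.linspace(0, 1, 5))
    return np.unique(x)


# Integer covariate: quantiles at 0%, 25%, 50%, 75%, 100%
print(default_values(np.array([1, 2, 3, 4])))  # [1.   1.75 2.5  3.25 4.  ]
```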
conditional = ["camper", "persons", "child"]
default_grid = data_grid(fish_model, conditional)

default_grid.shape, user_passed_grid.shape
Default computed for conditional variable: camper, persons, child
Default computed for unspecified variable: livebait
((32, 4), (24, 4))
Notice that the resulting lengths differ between the user-passed and default grids. This is because the default values for `child` range from 0 to 3, whereas the user-passed grid used 1 to 3.
default_grid.query("camper == 0")
|    | camper | persons | child | livebait |
|----|--------|---------|-------|----------|
| 0  | 0 | 1 | 0 | 1 |
| 1  | 0 | 1 | 1 | 1 |
| 2  | 0 | 1 | 2 | 1 |
| 3  | 0 | 1 | 3 | 1 |
| 4  | 0 | 2 | 0 | 1 |
| 5  | 0 | 2 | 1 | 1 |
| 6  | 0 | 2 | 2 | 1 |
| 7  | 0 | 2 | 3 | 1 |
| 8  | 0 | 3 | 0 | 1 |
| 9  | 0 | 3 | 1 | 1 |
| 10 | 0 | 3 | 2 | 1 |
| 11 | 0 | 3 | 3 | 1 |
| 12 | 0 | 4 | 0 | 1 |
| 13 | 0 | 4 | 1 | 1 |
| 14 | 0 | 4 | 2 | 1 |
| 15 | 0 | 4 | 3 | 1 |
Compute comparisons
To use `data_grid` to help generate data for computing comparisons or slopes, additional data is passed to the optional `variable` parameter. The name `variable` is an abstraction for the comparisons parameter `contrast` and the slopes parameter `wrt`. If you have used any of the `interpret` functions, these parameter names should be familiar, and the use of `data_grid` should be analogous to `comparisons`, `predictions`, and `slopes`.
`variable` can be passed either user-provided data (as a dictionary) or a string naming the covariate of interest. In the latter case, a default value is computed. Additionally, if an argument is passed for `variable`, then `effect_type` also needs to be passed, because for `comparisons` and `slopes` an epsilon value `eps` must be determined to compute the centered and finite difference, respectively. You can also pass a value for `eps` as a kwarg.
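To see why an `eps` is needed, here is a toy sketch of a centered finite difference (not the `interpret` source): the covariate is evaluated at `x - eps / 2` and `x + eps / 2`, and the difference in predictions is divided by `eps` to approximate the slope.

```python
def central_difference(f, x: float, eps: float = 1e-4) -> float:
    """Centered finite difference: approximate df/dx at x by
    evaluating f at x - eps/2 and x + eps/2."""
    return (f(x + eps / 2) - f(x - eps / 2)) / eps


# The slope of x**2 at x = 3 is 6; the approximation recovers it closely
print(central_difference(lambda x: x**2, 3.0))  # ≈ 6.0
```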
conditional = {
    "camper": np.array([0, 1]),
    "persons": np.arange(1, 5, 1),
    "child": np.array([1, 2, 3, 4])
}

variable = "livebait"

grid = data_grid(fish_model, conditional, variable, effect_type="comparisons")
Default computed for contrast variable: livebait
idata_grid = fish_model.predict(fish_idata, data=grid, inplace=False)
Select draws conditional on data
The second helper function to aid in more advanced analysis is `select_draws`. This function selects the posterior or posterior predictive draws from the ArviZ InferenceData object returned by `model.predict`, given a conditional dictionary. The conditional dictionary specifies the data values whose corresponding draws should be selected.

For example, to select the posterior draws where `livebait = [0, 1]`, all we need to do is pass a dictionary where the key is the name of the covariate and the value is the value we want to condition on (or select). The resulting InferenceData object will contain the draws that correspond to the data points where `livebait = [0, 1]`. Additionally, you must pass the InferenceData object returned by `model.predict`, the data used to generate the predictions, and the name of the data variable `data_var` you would like to select from the InferenceData posterior group. If you requested posterior predictive samples with `model.predict(..., kind="pps")`, you can use that group instead of the posterior group by passing `group="posterior_predictive"`.
Below, it is demonstrated how to compute comparisons for `mu` for the contrast `livebait = [0, 1]` using the posterior draws.
idata_grid

The output shows the InferenceData groups: `posterior` (with data variables `Intercept`, `camper`, `child`, `livebait`, `persons`, `psi`, and `mu` over dimensions `(chain, draw, __obs__)`), `sample_stats`, and `observed_data`.
draw_1 = select_draws(idata_grid, grid, {"livebait": 0}, "mu")
draw_2 = select_draws(idata_grid, grid, {"livebait": 1}, "mu")

comparison_mean = (draw_2 - draw_1).mean(("chain", "draw"))
comparison_hdi = az.hdi(draw_2 - draw_1)

comparison_df = pd.DataFrame(
    {
        "mean": comparison_mean.values,
        "hdi_low": comparison_hdi.sel(hdi="lower")["mu"].values,
        "hdi_high": comparison_hdi.sel(hdi="higher")["mu"].values,
    }
)

comparison_df.head(10)
|   | mean | hdi_low | hdi_high |
|---|------|---------|----------|
| 0 | 0.216038 | 0.150384 | 0.294501 |
| 1 | 0.054384 | 0.029894 | 0.077987 |
| 2 | 0.013811 | 0.005794 | 0.021496 |
| 3 | 0.003539 | 0.001142 | 0.006083 |
| 4 | 0.515839 | 0.375190 | 0.670931 |
| 5 | 0.129765 | 0.078764 | 0.182346 |
| 6 | 0.032932 | 0.014961 | 0.050478 |
| 7 | 0.008432 | 0.002665 | 0.014138 |
| 8 | 1.234183 | 0.926056 | 1.553555 |
| 9 | 0.310259 | 0.195251 | 0.425951 |
We can compare these results with those from `bmb.interpret.comparisons`.
summary_df = bmb.interpret.comparisons(
    fish_model,
    fish_idata,
    contrast={"livebait": [0, 1]},
    conditional=conditional
)

summary_df.head(10)
|   | term | estimate_type | value | camper | persons | child | estimate | lower_3.0% | upper_97.0% |
|---|------|---------------|-------|--------|---------|-------|----------|------------|-------------|
| 0 | livebait | diff | (0, 1) | 0 | 1 | 1 | 0.216038 | 0.150384 | 0.294501 |
| 1 | livebait | diff | (0, 1) | 0 | 1 | 2 | 0.054384 | 0.029894 | 0.077987 |
| 2 | livebait | diff | (0, 1) | 0 | 1 | 3 | 0.013811 | 0.005794 | 0.021496 |
| 3 | livebait | diff | (0, 1) | 0 | 1 | 4 | 0.003539 | 0.001142 | 0.006083 |
| 4 | livebait | diff | (0, 1) | 0 | 2 | 1 | 0.515839 | 0.375190 | 0.670931 |
| 5 | livebait | diff | (0, 1) | 0 | 2 | 2 | 0.129765 | 0.078764 | 0.182346 |
| 6 | livebait | diff | (0, 1) | 0 | 2 | 3 | 0.032932 | 0.014961 | 0.050478 |
| 7 | livebait | diff | (0, 1) | 0 | 2 | 4 | 0.008432 | 0.002665 | 0.014138 |
| 8 | livebait | diff | (0, 1) | 0 | 3 | 1 | 1.234183 | 0.926056 | 1.553555 |
| 9 | livebait | diff | (0, 1) | 0 | 3 | 2 | 0.310259 | 0.195251 | 0.425951 |
Aside from the additional information in `summary_df`, the columns `estimate`, `lower_3.0%`, and `upper_97.0%` are identical.
Cross comparisons
Computing a cross-comparison is useful when we want to compare contrasts as two (or more) predictors change at the same time. Cross-comparisons are currently not supported in the `comparisons` function, but we can use `select_draws` to compute them. For example, imagine we are interested in computing the cross-comparison between the two rows below.
summary_df.iloc[:2]
|   | term | estimate_type | value | camper | persons | child | estimate | lower_3.0% | upper_97.0% |
|---|------|---------------|-------|--------|---------|-------|----------|------------|-------------|
| 0 | livebait | diff | (0, 1) | 0 | 1 | 1 | 0.216038 | 0.150384 | 0.294501 |
| 1 | livebait | diff | (0, 1) | 0 | 1 | 2 | 0.054384 | 0.029894 | 0.077987 |
The cross-comparison amounts to first computing the comparison for row 0, given below, which can be verified against the estimate in `summary_df`.
cond_10 = {
    "camper": 0,
    "persons": 1,
    "child": 1,
    "livebait": 0
}

cond_11 = {
    "camper": 0,
    "persons": 1,
    "child": 1,
    "livebait": 1
}

draws_10 = select_draws(idata_grid, grid, cond_10, "mu")
draws_11 = select_draws(idata_grid, grid, cond_11, "mu")

(draws_11 - draws_10).mean(("chain", "draw")).item()
0.21603768204145285
Next, we need to compute the comparison for row 1.
cond_20 = {
    "camper": 0,
    "persons": 1,
    "child": 2,
    "livebait": 0
}

cond_21 = {
    "camper": 0,
    "persons": 1,
    "child": 2,
    "livebait": 1
}

draws_20 = select_draws(idata_grid, grid, cond_20, "mu")
draws_21 = select_draws(idata_grid, grid, cond_21, "mu")

(draws_21 - draws_20).mean(("chain", "draw")).item()
0.0543836038546379
Next, we compute the "first level" comparisons (`diff_1` and `diff_2`). Subsequently, we compute the difference between these two differences to obtain the cross-comparison.
diff_1 = (draws_11 - draws_10)
diff_2 = (draws_21 - draws_20)

cross_comparison = (diff_2 - diff_1).mean(("chain", "draw")).item()
cross_comparison
To verify this is correct, we can perform the cross-comparison directly on the `summary_df`.
summary_df.iloc[1]["estimate"] - summary_df.iloc[0]["estimate"]
-0.16165407818681496
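This difference-in-differences step generalizes readily; below is a small helper sketched with plain NumPy arrays standing in for the selected draws (all names here are hypothetical, not part of the Bambi API):

```python
import numpy as np


def cross_comparison(draws_a0, draws_a1, draws_b0, draws_b1):
    """Difference-in-differences of draws: the contrast under condition b
    minus the contrast under condition a, averaged over all draws."""
    diff_a = draws_a1 - draws_a0
    diff_b = draws_b1 - draws_b0
    return float(np.mean(diff_b - diff_a))


# Toy draws: the contrast is +2.0 under condition a and +0.5 under condition b
rng = np.random.default_rng(0)
base = rng.normal(size=1000)
print(cross_comparison(base, base + 2.0, base, base + 0.5))  # ≈ -1.5
```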
Summary
In this notebook, the `interpret` helper functions `data_grid` and `select_draws` were introduced, and it was demonstrated how they can be used to compute pairwise grids of data and cross-comparisons. With these functions, it is left to the user to generate their grids of data and quantities of interest, allowing for more flexibility and control over the data passed to `model.predict` and the quantities of interest computed.
%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Sun May 26 2024
Python implementation: CPython
Python version : 3.11.9
IPython version : 8.24.0
numpy : 1.26.4
arviz : 0.18.0
bambi : 0.13.1.dev39+gb7d6a6cb
pandas: 2.2.2
Watermark: 2.4.3