Categorical Regression

In this example, we will use the categorical family to model outcomes with more than two categories. The examples in this notebook were constructed by Tomás Capretto, and assembled into this example by Tyler James Burch (@tjburch on GitHub).

import arviz as az
import bambi as bmb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import warnings

from matplotlib.lines import Line2D

warnings.simplefilter(action="ignore", category=FutureWarning)
SEED = 1234
az.style.use("arviz-darkgrid")

When modeling binary outcomes with Bambi, the Bernoulli family is used. The multivariate generalization of the Bernoulli family is the Categorical family, and with it, we can model an arbitrary number of outcome categories.
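In symbols, the Bernoulli family models a single probability \(P(Y = 1) = p\), while the categorical family models one probability per category, \(P(Y = k) = p_k\) with \(\sum_k p_k = 1\).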

Example with toy data

To start, we will create a toy dataset with three classes.

rng = np.random.default_rng(SEED)
x = np.hstack([rng.normal(m, s, size=50) for m, s in zip([-2.5, 0, 2.5], [1.2, 0.5, 1.2])])
y = np.array(["A"] * 50 + ["B"] * 50 + ["C"] * 50)

colors = ["C0"] * 50 + ["C1"] * 50 + ["C2"] * 50
plt.scatter(x, np.random.uniform(size=150), color=colors)
plt.xlabel("x")
plt.ylabel("y");

Here we have 3 classes, generated from three normal distributions: \(N(-2.5, 1.2)\), \(N(0, 0.5)\), and \(N(2.5, 1.2)\). Creating a model to fit these distributions,

data = pd.DataFrame({"y": y, "x": x})
model = bmb.Model("y ~ x", data, family="categorical")
idata = model.fit()
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [Intercept, x]
100.00% [8000/8000 00:01<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.

Note that we pass family="categorical" to Bambi's Model constructor in order to use the categorical family. Here the response values are strings ("A", "B", "C"), but the response could also be stored as a pd.Categorical column.
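For example, here is a minimal sketch of the same model with the response re-encoded as a pd.Categorical (data_cat is just an illustrative name; the first level listed typically acts as the reference category):

data_cat = data.assign(y=pd.Categorical(data["y"], categories=["A", "B", "C"]))  # same data, categorical response
model_cat = bmb.Model("y ~ x", data_cat, family="categorical")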

Next we will use posterior predictions to visualize the mean class probability across the \(x\) spectrum.

x_new = np.linspace(-5, 5, num=200)
model.predict(idata, data=pd.DataFrame({"x": x_new}))
p = idata.posterior["y_mean"].sel(draw=slice(0, None, 10))

for j, g in enumerate("ABC"):
    plt.plot(x_new, p.sel({"y_dim": g}).stack(samples=("chain", "draw")), color=f"C{j}", alpha=0.2)

plt.xlabel("x")
plt.ylabel("y");

Here we can see the predicted probability shift from one class to the next as we move from left to right across \(x\). At every point along \(x\), the class probabilities sum to 1, since under our generative model each observation must belong to exactly one of the three classes.
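We can check this numerically with the posterior means computed above (a quick sanity check using the same y_dim coordinate as the plot):

# the class probabilities should sum to 1 for every draw and every value of x
total = p.sum("y_dim")
print(float(total.min()), float(total.max()))  # both should be approximately 1.0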

The iris dataset

Next, we will look at the classic “iris” dataset, which contains samples from 3 different species of iris plants. Using properties of the plant, we will try to model its species.

iris = sns.load_dataset("iris")
iris.head(3)
sepal_length sepal_width petal_length petal_width species
0 5.1 3.5 1.4 0.2 setosa
1 4.9 3.0 1.4 0.2 setosa
2 4.7 3.2 1.3 0.2 setosa

The dataset includes four properties of each plant: its sepal length, sepal width, petal length, and petal width. There are three possible species: setosa, versicolor, and virginica.

sns.pairplot(iris, hue="species");

We can see the three species have several distinct characteristics, which our linear model can capture to distinguish between them.

model = bmb.Model(
    "species ~ sepal_length + sepal_width + petal_length + petal_width", 
    iris, 
    family="categorical",
)
idata = model.fit()
az.summary(idata)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [Intercept, sepal_length, sepal_width, petal_length, petal_width]
100.00% [8000/8000 00:06<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 6 seconds.
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
Intercept[versicolor] -6.723 7.921 -20.761 8.601 0.157 0.120 2571.0 2554.0 1.0
Intercept[virginica] -22.730 9.630 -42.231 -5.452 0.176 0.129 3008.0 2837.0 1.0
sepal_length[versicolor] 3.096 1.732 -0.219 6.223 0.038 0.029 2060.0 2088.0 1.0
sepal_length[virginica] 2.341 1.786 -0.892 5.801 0.040 0.031 2040.0 2072.0 1.0
sepal_width[versicolor] -4.712 1.981 -8.639 -1.108 0.044 0.032 2083.0 2260.0 1.0
sepal_width[virginica] -6.638 2.409 -11.017 -1.936 0.054 0.040 2031.0 2415.0 1.0
petal_length[versicolor] 1.083 0.905 -0.657 2.797 0.018 0.013 2464.0 2559.0 1.0
petal_length[virginica] 4.010 1.053 1.940 5.884 0.020 0.014 2797.0 2458.0 1.0
petal_width[versicolor] 1.903 2.049 -1.881 5.854 0.044 0.032 2153.0 2357.0 1.0
petal_width[virginica] 9.098 2.310 4.541 13.170 0.049 0.037 2230.0 2367.0 1.0
az.plot_trace(idata);

We can see that this has fit quite nicely. Notice that there are \(n-1\) sets of coefficients to fit, where \(n\) is the number of categories. Recall that in the binary case there is only one parameter set, since the model estimates the probability \(p\) of being in one class and \(1-p\) of being in the other. The categorical family extends this: we estimate \(p_1\) for class 1 and \(p_2\) for class 2, leaving \(1-(p_1+p_2)\) for the final class.
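Roughly speaking, the categorical family pins the linear predictor of the reference category (here setosa, the first level) at zero and maps the predictors to probabilities with a softmax, which is why only \(n-1\) coefficient sets appear in the summary. A minimal NumPy sketch of that mapping (the function name and example values are purely illustrative):

def class_probabilities(eta_versicolor, eta_virginica):
    # linear predictors for the non-reference categories; the reference (setosa) is fixed at 0
    eta = np.array([0.0, eta_versicolor, eta_virginica])
    exp_eta = np.exp(eta - eta.max())  # subtract the max for numerical stability
    return exp_eta / exp_eta.sum()  # probabilities for [setosa, versicolor, virginica]

class_probabilities(1.0, -0.5)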

Using numerical and categorical predictors

Next we will look at an example from chapter 8 of Alan Agresti's Categorical Data Analysis: the primary food choice of 63 alligators caught in Lake George, Florida. We will use their length (a continuous predictor) and sex (a categorical predictor) to model their food choice.

First, reproducing the dataset,

length = [
    1.3, 1.32, 1.32, 1.4, 1.42, 1.42, 1.47, 1.47, 1.5, 1.52, 1.63, 1.65, 1.65, 1.65, 1.65,
    1.68, 1.7, 1.73, 1.78, 1.78, 1.8, 1.85, 1.93, 1.93, 1.98, 2.03, 2.03, 2.31, 2.36, 2.46,
    3.25, 3.28, 3.33, 3.56, 3.58, 3.66, 3.68, 3.71, 3.89, 1.24, 1.3, 1.45, 1.45, 1.55, 1.6, 
    1.6, 1.65, 1.78, 1.78, 1.8, 1.88, 2.16, 2.26, 2.31, 2.36, 2.39, 2.41, 2.44, 2.56, 2.67, 
    2.72, 2.79, 2.84
]
choice = [
    "I", "F", "F", "F", "I", "F", "I", "F", "I", "I", "I", "O", "O", "I", "F", "F", 
    "I", "O", "F", "O", "F", "F", "I", "F", "I", "F", "F", "F", "F", "F", "O", "O", 
    "F", "F", "F", "F", "O", "F", "F", "I", "I", "I", "O", "I", "I", "I", "F", "I", 
    "O", "I", "I", "F", "F", "F", "F", "F", "F", "F", "O", "F", "I", "F", "F"
]

sex = ["Male"] * 32 + ["Female"] * 31
data = pd.DataFrame({"choice": choice, "length": length, "sex": sex})
data["choice"]  = pd.Categorical(
    data["choice"].map({"I": "Invertebrates", "F": "Fish", "O": "Other"}), 
    ["Other", "Invertebrates", "Fish"], 
    ordered=True
)
data.head(3)
choice length sex
0 Invertebrates 1.30 Male
1 Fish 1.32 Male
2 Fish 1.32 Male

Next, constructing the model,

model = bmb.Model("choice ~ length + sex", data, family="categorical")
idata = model.fit()
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [Intercept, length, sex]
100.00% [8000/8000 00:01<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.

Using bmb.interpret.plot_predictions, we can visualize how the probability of the different response levels varies conditional on a set of predictors. In the plot below, we visualize how the food choices vary by length for both male and female alligators. Note how estimate_dim (the response level) is passed as the value of the group key.

bmb.interpret.plot_predictions(
    model,
    idata,
    ["length", "sex"],
    subplot_kwargs={"main": "length", "group": "estimate_dim", "panel": "sex"},
    fig_kwargs={"figsize": (12, 4)},
    legend=True
);

Here we can see that the larger alligators are, whether male or female, the less they prefer invertebrates and the more they prefer fish. Additionally, males appear more likely than females to consume "other" foods at any size. Of note, the posterior means predicted by Bambi contain information about all \(n\) categories (despite there being only \(n-1\) coefficient sets), so we can construct this plot directly rather than manually computing \(1-(p_1+p_2)\) for the third class, as the sketch below shows.
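For instance, a sketch of pulling those posterior mean probabilities out directly (the new lengths are arbitrary, and choice_mean mirrors the y_mean naming used earlier):

new = pd.DataFrame({"length": [1.5, 2.5, 3.5], "sex": ["Male"] * 3})
preds = model.predict(idata, data=new, inplace=False)
# one probability per food choice per alligator; each row sums to 1
preds.posterior["choice_mean"].mean(("chain", "draw")).to_pandas()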

Last, we can make a posterior predictive plot,

model.predict(idata, kind="pps")

ax = az.plot_ppc(idata)
ax.set_xticks([0.5, 1.5, 2.5])
ax.set_xticklabels(model.response_component.response_term.levels)
ax.set_xlabel("Choice");
ax.set_ylabel("Probability");

This depicts the posterior predictive probability of each possible food choice, reinforcing that fish is the most likely choice, followed by invertebrates.
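If we want the same comparison numerically, a rough sketch (assuming the posterior predictive draws are stored under the response name and are integer-coded, as the tick relabeling above suggests):

levels = model.response_component.response_term.levels
draws = idata.posterior_predictive["choice"].values.ravel()
predicted = pd.Series(draws).map(dict(enumerate(levels))).value_counts(normalize=True)
observed = data["choice"].value_counts(normalize=True)
print(observed, predicted, sep="\n\n")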

References

Agresti, A. (2013) Categorical Data Analysis. 3rd Edition, John Wiley & Sons Inc., Hoboken.

%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Tue Oct 10 2023

Python implementation: CPython
Python version       : 3.11.0
IPython version      : 8.13.2

pandas    : 2.1.0
matplotlib: 3.7.1
numpy     : 1.24.2
arviz     : 0.16.1
seaborn   : 0.12.2
bambi     : 0.13.0.dev0

Watermark: 2.3.1