Categorical Regression

In this example, we will use the categorical family to model outcomes with more than two categories. The examples in this notebook were constructed by Tomás Capretto, and assembled into this example by Tyler James Burch (@tjburch on GitHub).

import arviz as az
import bambi as bmb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import warnings

from matplotlib.lines import Line2D

warnings.simplefilter(action="ignore", category=FutureWarning)
SEED = 1234
az.style.use("arviz-darkgrid")

When modeling binary outcomes with Bambi, the Bernoulli family is used. The Categorical family generalizes the Bernoulli family to more than two categories, and with it, we can model an arbitrary number of outcome categories.

Example with toy data

To start, we will create a toy dataset with three classes.

rng = np.random.default_rng(SEED)
x = np.hstack([rng.normal(m, s, size=50) for m, s in zip([-2.5, 0, 2.5], [1.2, 0.5, 1.2])])
y = np.array(["A"] * 50 + ["B"] * 50 + ["C"] * 50)

colors = ["C0"] * 50 + ["C1"] * 50 + ["C2"] * 50
plt.scatter(x, rng.uniform(size=150), color=colors)
plt.xlabel("x")
plt.ylabel("y");

Here we have 3 classes, generated from three normal distributions: \(N(-2.5, 1.2)\), \(N(0, 0.5)\), and \(N(2.5, 1.2)\); the vertical spread in the plot above is just random jitter so the points don't overlap. Creating a model to fit these distributions,

data = pd.DataFrame({"y": y, "x": x})
model = bmb.Model("y ~ x", data, family="categorical")
idata = model.fit()
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [Intercept, x]


Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 4 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics

Note that we pass the family="categorical" argument to Bambi’s Model constructor in order to use the categorical family. Here, the values of the response variable are strings (“A”, “B”, “C”), but they could also be a pd.Categorical column.
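For instance, a pd.Categorical response also lets us control the order of the levels and, presumably, which level acts as the reference category (the first one). A quick sketch with the toy data:

# Sketch: list "B" first so it becomes the first level of the categorical and,
# presumably, the reference category instead of the alphabetically-first "A".
data_cat = data.assign(y=pd.Categorical(data["y"], categories=["B", "A", "C"]))
model_cat = bmb.Model("y ~ x", data_cat, family="categorical")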

Next we will use posterior predictions to visualize the mean class probability across the \(x\) spectrum.

x_new = np.linspace(-5, 5, num=200)
model.predict(idata, data=pd.DataFrame({"x": x_new}))
p = idata.posterior["p"].sel(draw=slice(0, None, 10))

for j, g in enumerate("ABC"):
    plt.plot(x_new, p.sel({"y_dim": g}).stack(samples=("chain", "draw")), color=f"C{j}", alpha=0.2)

plt.xlabel("x")
plt.ylabel("y");

Here we can see how the probability transitions between classes as we move from left to right along \(x\). At every point along \(x\), the class probabilities sum to 1, since under our generative model the outcome must be one of these three classes.
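We can check this numerically by summing the posterior probabilities over the class dimension:

# The class probabilities form a simplex: summing over the class dimension
# should give 1 (up to floating point error) for every draw and every x value.
assert np.allclose(p.sum("y_dim"), 1.0)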

The iris dataset

Next, we will look at the classic “iris” dataset, which contains measurements of samples from 3 different species of iris plants. Using properties of the plants, we will try to model their species.

iris = sns.load_dataset("iris")
iris.head(3)
   sepal_length  sepal_width  petal_length  petal_width species
0           5.1          3.5           1.4          0.2  setosa
1           4.9          3.0           1.4          0.2  setosa
2           4.7          3.2           1.3          0.2  setosa

The dataset includes four different properties of the plants: their sepal length, sepal width, petal length, and petal width. There are 3 possible species: setosa, versicolor, and virginica.

sns.pairplot(iris, hue="species");

We can see the three species have several distinct characteristics, which our linear model can capture to distinguish between them.

model = bmb.Model(
    "species ~ sepal_length + sepal_width + petal_length + petal_width", 
    iris, 
    family="categorical",
)
idata = model.fit()
az.summary(idata)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [Intercept, sepal_length, sepal_width, petal_length, petal_width]


Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 15 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
                             mean     sd   hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
Intercept[versicolor]      -7.054  7.937  -21.843    7.311      0.205    0.156    1498.0    1382.0    1.0
Intercept[virginica]      -22.961  9.387  -41.590   -6.223      0.234    0.169    1603.0    1433.0    1.0
petal_length[versicolor]    1.064  0.922   -0.670    2.710      0.027    0.021    1172.0    1269.0    1.0
petal_length[virginica]     4.030  1.088    1.993    6.133      0.028    0.020    1543.0    1350.0    1.0
petal_width[versicolor]     1.926  2.040   -2.077    5.427      0.059    0.043    1195.0    1092.0    1.0
petal_width[virginica]      9.049  2.263    4.768   13.117      0.061    0.043    1365.0    1118.0    1.0
sepal_length[versicolor]    3.144  1.717   -0.104    6.211      0.053    0.040    1050.0    1238.0    1.0
sepal_length[virginica]     2.341  1.734   -0.738    5.722      0.054    0.040    1029.0    1234.0    1.0
sepal_width[versicolor]    -4.674  1.865   -8.205   -1.402      0.054    0.038    1183.0    1337.0    1.0
sepal_width[virginica]     -6.554  2.272  -10.763   -2.217      0.065    0.046    1214.0    1283.0    1.0
az.plot_trace(idata);

We can see that this has fit quite nicely. You’ll notice there are \(n-1\) sets of coefficients to fit, where \(n\) is the number of categories. In the binary case, recall there is only one parameter set, since the model estimates the probability \(p\) of being in one class and \(1-p\) of being in the other. The categorical distribution extends this: we estimate \(p_1\) for class 1 and \(p_2\) for class 2, leaving \(1-(p_1+p_2)\) for the final class.
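Concretely, the categorical family maps the linear predictors to probabilities through a softmax, with the reference category (setosa here) pinned at zero, which is why only versicolor and virginica coefficients appear in the summary. A small sketch of that mapping, using made-up values for the linear predictors:

# Softmax sketch: the reference category's linear predictor is pinned at 0, and the
# other two come from the fitted coefficients (the numbers below are made up).
eta = np.array([0.0, 1.5, -0.3])         # [setosa, versicolor, virginica]
probs = np.exp(eta) / np.exp(eta).sum()  # softmax: probabilities sum to 1
print(probs.round(3))                    # array([0.161, 0.72 , 0.119])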

Using numerical and categorical predictors

Next we will look at an example from chapter 8 of Alan Agresti’s Categorical Data Analysis, which examines the primary food choice of 63 alligators caught in Lake George, Florida. We will use their length (a continuous variable) and sex (a categorical variable) as predictors to model their food choice.

First, reproducing the dataset,

length = [
    1.3, 1.32, 1.32, 1.4, 1.42, 1.42, 1.47, 1.47, 1.5, 1.52, 1.63, 1.65, 1.65, 1.65, 1.65,
    1.68, 1.7, 1.73, 1.78, 1.78, 1.8, 1.85, 1.93, 1.93, 1.98, 2.03, 2.03, 2.31, 2.36, 2.46,
    3.25, 3.28, 3.33, 3.56, 3.58, 3.66, 3.68, 3.71, 3.89, 1.24, 1.3, 1.45, 1.45, 1.55, 1.6, 
    1.6, 1.65, 1.78, 1.78, 1.8, 1.88, 2.16, 2.26, 2.31, 2.36, 2.39, 2.41, 2.44, 2.56, 2.67, 
    2.72, 2.79, 2.84
]
choice = [
    "I", "F", "F", "F", "I", "F", "I", "F", "I", "I", "I", "O", "O", "I", "F", "F", 
    "I", "O", "F", "O", "F", "F", "I", "F", "I", "F", "F", "F", "F", "F", "O", "O", 
    "F", "F", "F", "F", "O", "F", "F", "I", "I", "I", "O", "I", "I", "I", "F", "I", 
    "O", "I", "I", "F", "F", "F", "F", "F", "F", "F", "O", "F", "I", "F", "F"
]

sex = ["Male"] * 32 + ["Female"] * 31
data = pd.DataFrame({"choice": choice, "length": length, "sex": sex})
data["choice"]  = pd.Categorical(
    data["choice"].map({"I": "Invertebrates", "F": "Fish", "O": "Other"}), 
    ["Other", "Invertebrates", "Fish"], 
    ordered=True
)
data.head(3)
          choice  length   sex
0  Invertebrates    1.30  Male
1           Fish    1.32  Male
2           Fish    1.32  Male

Next, constructing the model,

model = bmb.Model("choice ~ length + sex", data, family="categorical")
idata = model.fit()
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [Intercept, length, sex]


Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 5 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics

Using bmb.interpret.plot_predictions, we can visualize how the probability of the different response levels varies conditional on a set of predictors. In the plot below, we show how the food choices vary by length for both male and female alligators. Note how estimate_dim (the dimension holding the response levels) is mapped to the group key.

bmb.interpret.plot_predictions(
    model,
    idata,
    ["length", "sex"],
    subplot_kwargs={"main": "length", "group": "estimate_dim", "panel": "sex"},
    fig_kwargs={"figsize": (12, 4)},
    legend=True
);
Default computed for conditional variable: length, sex

Here we can see that the larger alligators are, whether male or female, the less taste they have for invertebrates and the more strongly they prefer fish. Additionally, at any size, males seem to have a higher propensity than females to consume “other” foods. Of note, the posterior means predicted by Bambi contain information about all \(n\) categories (despite the model having coefficients for only \(n-1\) of them), so we can construct this plot directly rather than manually calculating \(1-(p_1+p_2)\) for the third class.
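We can also pull these probabilities out directly. As a quick sketch, predicting on a few hypothetical alligators (the lengths and sexes below are made up) returns the same “p” variable we used earlier, with mean probabilities for all three food choices:

# Hypothetical new alligators: three males of increasing length (made-up values).
new_gators = pd.DataFrame({"length": [1.5, 2.5, 3.5], "sex": ["Male"] * 3})
preds = model.predict(idata, data=new_gators, inplace=False)
# Mean probability of each food choice for each alligator; every row sums to 1.
print(preds.posterior["p"].mean(("chain", "draw")).round(2))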

Last, we can make a posterior predictive plot,

model.predict(idata, kind="pps")

ax = az.plot_ppc(idata)
ax.set_xticks([0.5, 1.5, 2.5])
ax.set_xticklabels(model.response_component.term.levels)
ax.set_xlabel("Choice");
ax.set_ylabel("Probability");

This depicts the posterior predictive probability of each possible food choice for an alligator, reinforcing that fish is the most likely food choice, followed by invertebrates.
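The same conclusion can be read off the posterior predictive draws themselves; as a quick sketch (the draws are integer-coded category indices in the order of the response levels), we can tabulate them directly:

# Tabulate the posterior predictive draws: relative frequency of each food choice.
draws = idata.posterior_predictive["choice"].values.ravel()
freqs = np.bincount(draws, minlength=3) / draws.size
print(dict(zip(model.response_component.term.levels, freqs.round(3))))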

References

Agresti, A. (2013) Categorical Data Analysis. 3rd Edition, John Wiley & Sons Inc., Hoboken.

%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Sat May 25 2024

Python implementation: CPython
Python version       : 3.11.9
IPython version      : 8.24.0

bambi     : 0.13.1.dev37+g2a54df76.d20240525
pandas    : 2.2.2
arviz     : 0.18.0
matplotlib: 3.8.4
seaborn   : 0.13.2
numpy     : 1.26.4

Watermark: 2.4.3