Categorical Regression#

In this example, we will use the categorical family to model outcomes with more than two categories. The examples in this notebook were constructed by Tomás Capretto, and assembled into this example by Tyler James Burch (@tjburch on GitHub).

[1]:
import arviz as az
import bambi as bmb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from matplotlib.lines import Line2D
[2]:
SEED = 1234
az.style.use("arviz-darkgrid")

When modeling binary outcomes with Bambi, the Bernoulli family is used. The multivariate generalization of the Bernoulli family is the Categorical family, and with it, we can model an arbitrary number of outcome categories.
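Before working through a full example, here is a minimal sketch (not part of the original notebook) contrasting the two families; the toy data frame and variable names below are made up purely for illustration:

rng_demo = np.random.default_rng(0)
demo = pd.DataFrame({
    "y_binary": np.tile([0, 1], 20),          # two outcome categories
    "y_multi": ["A", "B", "C", "A"] * 10,     # three outcome categories
    "x": rng_demo.normal(size=40),
})
# Same formula interface for both; only the family name changes
binary_model = bmb.Model("y_binary ~ x", demo, family="bernoulli")
multi_model = bmb.Model("y_multi ~ x", demo, family="categorical")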

Example with toy data#

To start, we will create a toy dataset with three classes.

[3]:
rng = np.random.default_rng(SEED)
x = np.hstack([rng.normal(m, s, size=50) for m, s in zip([-2.5, 0, 2.5], [1.2, 0.5, 1.2])])
y = np.array(["A"] * 50 + ["B"] * 50 + ["C"] * 50)

colors = ["C0"] * 50 + ["C1"] * 50 + ["C2"] * 50
plt.scatter(x, np.random.uniform(size=150), color=colors)
plt.xlabel("x")
plt.ylabel("y");
../_images/notebooks_categorical_regression_5_0.png

Here we have 3 classes, generated from three normal distributions: \(N(-2.5, 1.2)\), \(N(0, 0.5)\), and \(N(2.5, 1.2)\). Next, we create a model to fit these classes:

[4]:
data = pd.DataFrame({"y": y, "x": x})
model = bmb.Model("y ~ x", data, family="categorical")
idata = model.fit()
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [x, Intercept]
100.00% [4000/4000 00:04<00:00 Sampling 2 chains, 0 divergences]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 5 seconds.

Note that we pass the family="categorical" argument to Bambi’s Model constructor in order to use the categorical family. Here, the response values are strings (“A”, “B”, “C”), but the response can also be a pd.Categorical object.
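As a brief sketch (not an executed cell in this notebook), the same model could be fit with the response converted to a pd.Categorical, which also gives explicit control over the category order:

# Equivalent model with a pd.Categorical response (names here are illustrative)
data_cat = data.copy()
data_cat["y"] = pd.Categorical(data_cat["y"], categories=["A", "B", "C"])
model_cat = bmb.Model("y ~ x", data_cat, family="categorical")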

Next we will use posterior predictions to visualize the mean class probability across the \(x\) spectrum.

[8]:
x_new = np.linspace(-5, 5, num=200)
model.predict(idata, data=pd.DataFrame({"x": x_new}))
p = idata.posterior["y_mean"].sel(draw=slice(0, None, 10))

for j, g in enumerate("ABC"):
    plt.plot(x_new, p.sel({"y_mean_dim": g}).stack(samples=("chain", "draw")), color=f"C{j}", alpha=0.2)

plt.xlabel("x")
plt.ylabel("y");
../_images/notebooks_categorical_regression_10_0.png

Here, we can see that the predicted class probability transitions smoothly from one class to the next as we move from left to right across \(x\). At every point across \(x\), the class probabilities sum to 1, since in our generative model the outcome must be one of these three classes.
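We can verify this numerically; a quick sanity check (not an executed cell in the original notebook) using the p computed above:

# Class probabilities should sum to (numerically) 1 for every x and every posterior sample
print(bool(np.allclose(p.sum("y_mean_dim"), 1.0)))  # expected: True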

The iris dataset#

Next, we will look at the classic “iris” dataset, which contains samples from 3 different species of iris plants. Using measured properties of the plants, we will model their species.

[9]:
iris = sns.load_dataset("iris")
iris.head(3)
[9]:
sepal_length sepal_width petal_length petal_width species
0 5.1 3.5 1.4 0.2 setosa
1 4.9 3.0 1.4 0.2 setosa
2 4.7 3.2 1.3 0.2 setosa

The dataset includes four different properties of the plants: sepal length, sepal width, petal length, and petal width. There are 3 different class possibilities: setosa, versicolor, and virginica.

[10]:
sns.pairplot(iris, hue="species");
/home/tomas/anaconda3/envs/bambi/lib/python3.9/site-packages/seaborn/axisgrid.py:88: UserWarning: This figure was using constrained_layout, but that is incompatible with subplots_adjust and/or tight_layout; disabling constrained_layout.
  self._figure.tight_layout(*args, **kwargs)
../_images/notebooks_categorical_regression_15_1.png

We can see the three species have several distinct characteristics, which our linear model can capture to distinguish between them.

[11]:
model = bmb.Model(
    "species ~ sepal_length + sepal_width + petal_length + petal_width",
    iris,
    family="categorical",
)
idata = model.fit()
az.summary(idata)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [petal_width, petal_length, sepal_width, sepal_length, Intercept]
100.00% [4000/4000 00:51<00:00 Sampling 2 chains, 0 divergences]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 52 seconds.
[11]:
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
Intercept[versicolor] -6.959 7.909 -21.094 7.953 0.234 0.183 1154.0 999.0 1.0
Intercept[virginica] -22.725 9.226 -40.377 -5.868 0.269 0.194 1177.0 1316.0 1.0
sepal_length[versicolor] 3.155 1.674 -0.030 6.223 0.055 0.040 914.0 1141.0 1.0
sepal_length[virginica] 2.346 1.749 -0.946 5.668 0.062 0.044 806.0 1043.0 1.0
sepal_width[versicolor] -4.730 1.932 -8.587 -1.461 0.060 0.043 1022.0 1275.0 1.0
sepal_width[virginica] -6.633 2.323 -10.787 -2.187 0.073 0.051 1016.0 954.0 1.0
petal_length[versicolor] 1.059 0.914 -0.709 2.765 0.031 0.023 874.0 964.0 1.0
petal_length[virginica] 4.006 1.028 1.961 5.814 0.032 0.023 1000.0 1157.0 1.0
petal_width[versicolor] 1.961 2.088 -1.634 6.187 0.075 0.053 784.0 982.0 1.0
petal_width[virginica] 9.105 2.309 4.933 13.555 0.074 0.052 961.0 858.0 1.0
[12]:
az.plot_trace(idata);

../_images/notebooks_categorical_regression_18_0.png

We can see that this has fit quite nicely. You’ll notice there are \(n-1\) sets of parameters to fit, where \(n\) is the number of categories. In the minimal binary case, recall there’s only one parameter set, since it models the probability \(p\) of being in one class and the probability \(1-p\) of being in the other. The categorical distribution extends this: we have \(p_1\) for class 1, \(p_2\) for class 2, and \(1-(p_1+p_2)\) for the final class.
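Under the hood, the reference category (here setosa, the first level) has its linear predictor fixed at 0, and a softmax maps the \(n-1\) estimated linear predictors to \(n\) probabilities. The following sketch (with made-up linear predictor values, not taken from the fit above) illustrates the mapping:

# Hypothetical linear predictors for a single observation
eta_setosa = 0.0          # reference category, pinned at 0
eta_versicolor = 1.3      # made-up value for illustration
eta_virginica = -0.4      # made-up value for illustration

eta = np.array([eta_setosa, eta_versicolor, eta_virginica])
probs = np.exp(eta) / np.exp(eta).sum()   # softmax; probabilities sum to 1
print(dict(zip(["setosa", "versicolor", "virginica"], probs.round(3))))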

Using numerical and categorical predictors#

Next we will look at an example from chapter 8 of Alan Agresti’s Categorical Data Analysis, looking at the primary food choice for 64 alligators caught in Lake George, Florida. We will use their length (a continuous variable) and sex (a categorical variable) as predictors to model their food choice.

First, reproducing the dataset,

[13]:
length = [
    1.3, 1.32, 1.32, 1.4, 1.42, 1.42, 1.47, 1.47, 1.5, 1.52, 1.63, 1.65, 1.65, 1.65, 1.65,
    1.68, 1.7, 1.73, 1.78, 1.78, 1.8, 1.85, 1.93, 1.93, 1.98, 2.03, 2.03, 2.31, 2.36, 2.46,
    3.25, 3.28, 3.33, 3.56, 3.58, 3.66, 3.68, 3.71, 3.89, 1.24, 1.3, 1.45, 1.45, 1.55, 1.6,
    1.6, 1.65, 1.78, 1.78, 1.8, 1.88, 2.16, 2.26, 2.31, 2.36, 2.39, 2.41, 2.44, 2.56, 2.67,
    2.72, 2.79, 2.84
]
choice = [
    "I", "F", "F", "F", "I", "F", "I", "F", "I", "I", "I", "O", "O", "I", "F", "F",
    "I", "O", "F", "O", "F", "F", "I", "F", "I", "F", "F", "F", "F", "F", "O", "O",
    "F", "F", "F", "F", "O", "F", "F", "I", "I", "I", "O", "I", "I", "I", "F", "I",
    "O", "I", "I", "F", "F", "F", "F", "F", "F", "F", "O", "F", "I", "F", "F"
]

sex = ["Male"] * 32 + ["Female"] * 31
data = pd.DataFrame({"choice": choice, "length": length, "sex": sex})
data["choice"]  = pd.Categorical(
    data["choice"].map({"I": "Invertebrates", "F": "Fish", "O": "Other"}),
    ["Other", "Invertebrates", "Fish"],
    ordered=True
)
data.head(3)
[13]:
choice length sex
0 Invertebrates 1.30 Male
1 Fish 1.32 Male
2 Fish 1.32 Male

Next, constructing the model,

[14]:
model = bmb.Model("choice ~ length + sex", data, family="categorical")
idata = model.fit()
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [sex, length, Intercept]
100.00% [4000/4000 00:19<00:00 Sampling 2 chains, 0 divergences]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 20 seconds.

We can then look at how the food choices vary by length for both male and female alligators.

[15]:
new_length = np.linspace(1, 4)
new_data = pd.DataFrame({"length": np.tile(new_length, 2), "sex": ["Male"] * 50 + ["Female"] * 50})
model.predict(idata, data=new_data)
p = idata.posterior["choice_mean"]

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
choices = ["Other", "Invertebrates", "Fish"]

for j, choice in enumerate(choices):
    males = p.sel({"choice_mean_dim": choice, "choice_obs": slice(0, 49)})
    females = p.sel({"choice_mean_dim": choice, "choice_obs": slice(50, 100)})
    axes[0].plot(new_length, males.mean(("chain", "draw")), color=f"C{j}", lw=2)
    axes[1].plot(new_length, females.mean(("chain", "draw")), color=f"C{j}", lw=2)
    az.plot_hdi(new_length, males, color=f"C{j}", ax=axes[0])
    az.plot_hdi(new_length, females, color=f"C{j}", ax=axes[1])

axes[0].set_title("Male")
axes[1].set_title("Female")

handles = [Line2D([], [], color=f"C{j}", label=choice) for j, choice in enumerate(choices)]
fig.subplots_adjust(left=0.05, right=0.975, bottom=0.075, top=0.85)

fig.legend(
    handles,
    choices,
    loc="center right",
    ncol=3,
    bbox_to_anchor=(0.99, 0.95),
    bbox_transform=fig.transFigure
);
/tmp/ipykernel_38053/2509692565.py:21: UserWarning: This figure was using constrained_layout, but that is incompatible with subplots_adjust and/or tight_layout; disabling constrained_layout.
  fig.subplots_adjust(left=0.05, right=0.975, bottom=0.075, top=0.85)
../_images/notebooks_categorical_regression_25_1.png

Here we can see that the larger the alligators are, male or female, the less they favor invertebrates and the more they favor fish. Additionally, males seem to have a higher propensity to consume “other” foods than females at any size. Of note, the posterior means predicted by Bambi contain information about all \(n\) categories (despite the model having only \(n-1\) sets of coefficients), so we can construct this plot directly, rather than manually calculating \(1-(p_1+p_2)\) for the third class.
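For example, here is a quick sketch (not an executed cell in the original notebook) that uses the full set of predicted probabilities from the previous cell to read off the most likely food choice for each new alligator:

mean_probs = p.mean(("chain", "draw"))              # average over posterior samples
most_likely = mean_probs.idxmax("choice_mean_dim")  # category label with the highest probability
print(most_likely.values[:5])                       # first few of the new observations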

Last, we can make a posterior predictive plot,

[16]:
model.predict(idata, kind="pps")

ax = az.plot_ppc(idata)
ax.set_xticks([0.5, 1.5, 2.5])
ax.set_xticklabels(model.response.levels)
ax.set_xlabel("Choice");
ax.set_ylabel("Probability");
../_images/notebooks_categorical_regression_27_0.png

This depicts the posterior predictive probability of each possible food choice for an alligator, reinforcing that fish is the most likely food choice, followed by invertebrates.
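As a cross-check (a sketch, not part of the original notebook), the raw observed proportions of each food choice in the data point in the same direction:

# Observed proportions of each food choice, for comparison with the plot above
print(data["choice"].value_counts(normalize=True).round(2))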

References#

Agresti, A. (2013) Categorical Data Analysis. 3rd Edition, John Wiley & Sons Inc., Hoboken.