import arviz as az
import bambi as bmb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
Regression splines (Cherry blossom example)
This example shows how to specify and fit a spline regression in Bambi. It is based on the splines example from the PyMC docs.
"arviz-darkgrid")
az.style.use(= 7355608 SEED
Load Cherry Blossom data
Richard McElreath popularized the Cherry Blossom dataset in the second edition of his excellent book Statistical Rethinking. This data represents the day in the year when the first bloom is observed for Japanese cherry blossoms between years 801 and 2015. In his book, Richard McElreath uses this dataset to introduce Basis Splines, or B-Splines in short.
Here we use Bambi to fit a linear model using B-Splines with the Cherry Blossom data. This dataset can be loaded with Bambi as follows:
data = bmb.load_data("cherry_blossoms")
data
 | year | doy | temp | temp_upper | temp_lower |
---|---|---|---|---|---|
0 | 801 | NaN | NaN | NaN | NaN |
1 | 802 | NaN | NaN | NaN | NaN |
2 | 803 | NaN | NaN | NaN | NaN |
3 | 804 | NaN | NaN | NaN | NaN |
4 | 805 | NaN | NaN | NaN | NaN |
... | ... | ... | ... | ... | ... |
1210 | 2011 | 99.0 | NaN | NaN | NaN |
1211 | 2012 | 101.0 | NaN | NaN | NaN |
1212 | 2013 | 93.0 | NaN | NaN | NaN |
1213 | 2014 | 94.0 | NaN | NaN | NaN |
1214 | 2015 | 93.0 | NaN | NaN | NaN |
1215 rows × 5 columns
The variable we are interested in modeling is "doy", which stands for Day of Year. Also notice this variable contains several missing values, which are discarded next.
data = data.dropna(subset=["doy"]).reset_index(drop=True)
data.shape
(827, 5)
Explore the data
Let’s get started by creating a scatterplot to explore the values of "doy"
for each year in the dataset.
# We create a function because this plot is going to be used again later
def plot_scatter(data, figsize=(10, 6)):
    _, ax = plt.subplots(figsize=figsize)
    ax.scatter(data["year"], data["doy"], alpha=0.4, s=30)
    ax.set_title("Day of the first bloom per year")
    ax.set_xlabel("Year")
    ax.set_ylabel("Days of the first bloom")
    return ax

plot_scatter(data);
We can observe that the day of the first bloom ranges approximately between 85 and 125, which corresponds to late March and early May, respectively. On average, the first bloom occurs on the 105th day of the year, which is mid-April.
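If you want to double check these figures from the data itself, a quick summary (an optional addition, not part of the original example) can be computed as follows:

# Quick check of the figures mentioned above (values are approximate)
data["doy"].agg(["min", "max", "mean"])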
Determine knots
The spline will have 15 knots. These knots are the boundaries of the basis functions. These knots split the range of the "year"
variable into 16 contiguous sections. The basis functions make up a piecewise continuous polynomial, and so they are enforced to meet at the knots. We use the default degree for each piecewise polynomial, which is 3. The result is known as a cubic spline.
Because the knots are based on quantiles, and because we don't have observations for every year in the window under study, the knots are distributed unevenly over the range of "year", in such a way that roughly the same proportion of observations falls within each section.
num_knots = 15
knots = np.quantile(data["year"], np.linspace(0, 1, num_knots))
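As a small sanity check (an optional addition, not part of the original example), we can count how many observations fall between consecutive knots. Since the knots are quantiles of "year", the counts should be roughly equal:

# Optional sanity check: counts of observations between consecutive knots
# should be roughly balanced, because the knots are quantiles of "year".
np.histogram(data["year"], bins=knots)[0]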
def plot_knots(knots, ax):
    for knot in knots:
        ax.axvline(knot, color="0.1", alpha=0.4)
    return ax

ax = plot_scatter(data)
plot_knots(knots, ax);
The previous chart makes it easy to see that the knots, represented by the vertical lines, are spaced unevenly over the years.
The model
The B-spline model we are about to create is simply a linear regression model with synthetic predictor variables. These predictors are the basis functions that are derived from the original year
predictor.
In math notation, we use a \(\text{Normal}\) distribution for the conditional distribution of \(Y\) when \(X = x_i\), i.e. \(Y_i\), the distribution of the day of the first bloom in a given year.
\[ Y_i \sim \text{Normal}(\mu_i, \sigma) \]
So far, this looks like a regular linear regression model. The next line is where the spline comes into play:
\[ \mu_i = \alpha + \sum_{k=1}^K{w_kB_{k, i}} \]
The line above says that for each observation \(i\), the mean is a combination of all the basis functions (going from \(k=1\) to \(k=K\)) plus an intercept \(\alpha\). The \(w_k\) values in the summation are the regression coefficients of the basis functions, and \(B_{k, i}\) is the value of the \(k\)-th basis function at observation \(i\).
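To make the notation concrete, here is a minimal numerical sketch with made-up values (the arrays below are purely illustrative, they are not taken from the model):

# Toy example: two observations, K = 3 basis functions (all values made up)
B_toy = np.array([[0.2, 0.5, 0.3],
                  [0.0, 0.4, 0.6]])   # basis function values, one row per observation
w_toy = np.array([1.0, -2.0, 0.5])    # basis coefficients w_k
alpha_toy = 100.0                     # intercept alpha
mu_toy = alpha_toy + B_toy @ w_toy    # mean of the response for each observation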
Finally, we will be using the following priors
\[ \begin{aligned} \alpha & \sim \text{Normal}(100, 10) \\ w_k & \sim \text{Normal}(0, 10)\\ \sigma & \sim \text{Exponential}(1) \end{aligned} \]
where \(k\) indexes the basis functions defined by the knots.
# We only pass the internal knots to the `bs()` function.
iknots = knots[1:-1]

# Define dictionary of priors
priors = {
    "Intercept": bmb.Prior("Normal", mu=100, sigma=10),
    "common": bmb.Prior("Normal", mu=0, sigma=10),
    "sigma": bmb.Prior("Exponential", lam=1)
}

# Define model
# The intercept=True means the basis also spans the intercept, as originally done in the book example.
model = bmb.Model("doy ~ bs(year, knots=iknots, intercept=True)", data, priors=priors)
model
Formula: doy ~ bs(year, knots=iknots, intercept=True)
Family: gaussian
Link: mu = identity
Observations: 827
Priors:
target = mu
Common-level effects
Intercept ~ Normal(mu: 100.0, sigma: 10.0)
bs(year, knots=iknots, intercept=True) ~ Normal(mu: 0.0, sigma: 10.0)
Auxiliary parameters
sigma ~ Exponential(lam: 1.0)
Let’s create a function to plot each of the basis functions in the model.
def plot_spline_basis(basis, year, figsize=(10, 6)):
    df = (
        pd.DataFrame(basis)
        .assign(year=year)
        .melt("year", var_name="basis_idx", value_name="value")
    )

    _, ax = plt.subplots(figsize=figsize)

    for idx in df.basis_idx.unique():
        d = df[df.basis_idx == idx]
        ax.plot(d["year"], d["value"])

    return ax
Below, we create a chart to visualize the B-spline basis. The overlap between the functions means that, at any given point in time, the regression function is influenced by more than one basis function. For example, if we look at the year 1200, we can see the regression line is going to be influenced mostly by the violet and brown functions, and to a lesser extent by the green and cyan ones. In summary, this is what enables us to capture local patterns in a smooth fashion.
B = model.components["mu"].design.common["bs(year, knots=iknots, intercept=True)"]
ax = plot_spline_basis(B, data["year"].values)
plot_knots(knots, ax);
Fit model
Now we fit the model. In Bambi, it is as easy as calling the .fit()
method on the Model
instance.
# The seed is to make results reproducible
idata = model.fit(random_seed=SEED, idata_kwargs={"log_likelihood": True})
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [sigma, Intercept, bs(year, knots=iknots, intercept=True)]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 10 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
Analysis of the results
It is always good to use az.summary() to check parameter estimates as well as effective sample sizes and R-hat values. In this case, the main goal is not to interpret the coefficients of the spline basis, but to analyze the ess and r_hat diagnostics. First, the effective sample sizes don't look impressively high: most of them are only a few hundred, which is low compared to the 2,000 draws obtained. The only exception is the residual standard deviation sigma. Second, the r_hat diagnostic is not always exactly 1 for all the parameters, which suggests there may be some issues with the mixing of the chains.
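If you prefer to inspect these diagnostics programmatically instead of scanning the whole table, one option (not part of the original example) is to summarize only the relevant columns:

# Optional: summarize the sampling diagnostics discussed above
summary = az.summary(idata)
summary[["ess_bulk", "ess_tail", "r_hat"]].describe()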
az.summary(idata)
 | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
---|---|---|---|---|---|---|---|---|---|
Intercept | 103.644 | 2.419 | 99.691 | 108.377 | 0.115 | 0.082 | 446.0 | 723.0 | 1.0 |
bs(year, knots=iknots, intercept=True)[0] | -3.262 | 3.909 | -10.682 | 3.781 | 0.136 | 0.097 | 824.0 | 1250.0 | 1.0 |
bs(year, knots=iknots, intercept=True)[1] | -1.213 | 4.016 | -8.644 | 6.428 | 0.132 | 0.094 | 918.0 | 1337.0 | 1.0 |
bs(year, knots=iknots, intercept=True)[2] | -1.296 | 3.574 | -7.878 | 5.467 | 0.127 | 0.090 | 795.0 | 1190.0 | 1.0 |
bs(year, knots=iknots, intercept=True)[3] | 4.549 | 2.968 | -1.297 | 9.678 | 0.122 | 0.086 | 597.0 | 954.0 | 1.0 |
bs(year, knots=iknots, intercept=True)[4] | -1.137 | 2.894 | -6.794 | 4.092 | 0.120 | 0.085 | 577.0 | 892.0 | 1.0 |
bs(year, knots=iknots, intercept=True)[5] | 3.997 | 2.957 | -1.706 | 9.342 | 0.119 | 0.084 | 622.0 | 897.0 | 1.0 |
bs(year, knots=iknots, intercept=True)[6] | -5.603 | 2.911 | -10.905 | -0.138 | 0.121 | 0.086 | 578.0 | 861.0 | 1.0 |
bs(year, knots=iknots, intercept=True)[7] | 7.542 | 2.870 | 2.339 | 12.969 | 0.120 | 0.085 | 573.0 | 811.0 | 1.0 |
bs(year, knots=iknots, intercept=True)[8] | -1.304 | 2.956 | -7.137 | 3.645 | 0.119 | 0.084 | 624.0 | 1208.0 | 1.0 |
bs(year, knots=iknots, intercept=True)[9] | 2.753 | 2.991 | -2.551 | 8.544 | 0.126 | 0.089 | 576.0 | 1070.0 | 1.0 |
bs(year, knots=iknots, intercept=True)[10] | 4.337 | 2.961 | -1.121 | 9.648 | 0.117 | 0.083 | 640.0 | 1054.0 | 1.0 |
bs(year, knots=iknots, intercept=True)[11] | -0.459 | 2.897 | -6.121 | 4.573 | 0.117 | 0.083 | 615.0 | 981.0 | 1.0 |
bs(year, knots=iknots, intercept=True)[12] | 5.253 | 2.977 | -0.125 | 10.722 | 0.119 | 0.084 | 639.0 | 794.0 | 1.0 |
bs(year, knots=iknots, intercept=True)[13] | 0.454 | 3.055 | -5.468 | 6.024 | 0.123 | 0.087 | 620.0 | 1035.0 | 1.0 |
bs(year, knots=iknots, intercept=True)[14] | -1.123 | 3.330 | -7.470 | 4.998 | 0.128 | 0.091 | 683.0 | 969.0 | 1.0 |
bs(year, knots=iknots, intercept=True)[15] | -7.222 | 3.397 | -13.694 | -1.375 | 0.121 | 0.085 | 794.0 | 1080.0 | 1.0 |
bs(year, knots=iknots, intercept=True)[16] | -8.025 | 3.295 | -13.911 | -1.816 | 0.119 | 0.085 | 764.0 | 980.0 | 1.0 |
sigma | 5.944 | 0.147 | 5.673 | 6.220 | 0.003 | 0.002 | 2904.0 | 1417.0 | 1.0 |
We can also use az.plot_trace()
to visualize the marginal posteriors and the sampling paths. These traces show a stationary, random-looking pattern. If the paths were not stationary and random-looking, we would be concerned about the convergence of the chains.
az.plot_trace(idata);
Now we can visualize the fitted basis functions. In addition, we include a thicker black line that represents the dot product between \(B\) and \(w\). This is the contribution of the B-spline to the linear predictor in the model.
posterior_stacked = az.extract(idata)
wp = posterior_stacked["bs(year, knots=iknots, intercept=True)"].mean("sample").values

ax = plot_spline_basis(B * wp.T, data["year"].values)
ax.plot(data.year.values, np.dot(B, wp.T), color="black", lw=3)
plot_knots(knots, ax);
Plot predictions and credible bands
Let’s create a function to plot the predicted mean value as well as credible bands for it.
def plot_predictions(data, idata, model):
    # Create a test dataset with observations spanning the whole range of year
    new_data = pd.DataFrame({"year": np.linspace(data.year.min(), data.year.max(), num=500)})

    # Predict the day of first blossom
    model.predict(idata, data=new_data)

    # Extract these predictions
    posterior_stacked = az.extract_dataset(idata)
    y_hat = posterior_stacked["mu"]

    # Compute the mean of the predictions, plotted as a single line
    y_hat_mean = y_hat.mean("sample")

    # Compute 94% credible intervals for the predictions, plotted as bands
    hdi_data = np.quantile(y_hat, [0.03, 0.97], axis=1)

    # Plot observed data
    ax = plot_scatter(data)

    # Plot predicted line
    ax.plot(new_data["year"], y_hat_mean, color="firebrick")

    # Plot credibility bands
    ax.fill_between(new_data["year"], hdi_data[0], hdi_data[1], alpha=0.4, color="firebrick")

    # Add knots
    plot_knots(knots, ax)

    return ax

plot_predictions(data, idata, model);
/tmp/ipykernel_46679/4286558085.py:8: FutureWarning: extract_dataset has been deprecated, please use extract
posterior_stacked = az.extract_dataset(idata)
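As the warning notes, az.extract_dataset() has been deprecated in favor of az.extract(). If you want to avoid the warning, the extraction line inside plot_predictions() can be swapped for the newer call (a small tweak, assuming a recent ArviZ version):

# Equivalent extraction with the non-deprecated function
posterior_stacked = az.extract(idata)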
Advanced: Watch out for the underlying design matrix
We can write linear regression models in matrix form as
\[ \mathbf{y} = \mathbf{X}\boldsymbol{\beta} \]
where \(\mathbf{y}\) is the response column vector of shape \((n, 1)\), \(\mathbf{X}\) is the design matrix that contains the values of the predictors for all the observations, of shape \((n, p)\), and \(\boldsymbol{\beta}\) is the column vector of regression coefficients, of shape \((p, 1)\).
Because it’s not something you’re supposed to consult regularly, Bambi does not expose the design matrix. However, with some knowledge of the internals, it is possible to access it:
np.round(model.components["mu"].design.common.design_matrix, 3)
array([[1. , 1. , 0. , ..., 0. , 0. , 0. ],
[1. , 0.96 , 0.039, ..., 0. , 0. , 0. ],
[1. , 0.767, 0.221, ..., 0. , 0. , 0. ],
...,
[1. , 0. , 0. , ..., 0.002, 0.097, 0.902],
[1. , 0. , 0. , ..., 0. , 0.05 , 0.95 ],
[1. , 0. , 0. , ..., 0. , 0. , 1. ]])
Let’s have a look at its shape:
"mu"].design.common.design_matrix.shape model.components[
(827, 18)
827 is the number of years we have data for, and 18 is the number of predictors/coefficients in the model. The first column of ones corresponds to the Intercept term. Then there are sixteen columns associated with the basis functions. And finally, one extra column because we used intercept=True when calling bs() in the model formula.
Now we could compute the rank of the design matrix to check whether all the columns are linearly independent.
"mu"].design.common.design_matrix) np.linalg.matrix_rank(model.components[
17
Since \(\text{rank}(\mathbf{X})\) is smaller than the number of columns, we conclude the columns in \(\mathbf{X}\) are not linearly independent.
If we take a second look at our code, we will notice we’re spanning the intercept twice: the first time with the intercept term itself, and the second time within the spline basis.
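We can verify this numerically. The following optional check (not part of the original example) should show that the basis columns produced with intercept=True add up, row by row, to a column of ones, duplicating the Intercept column:

# Optional check: the basis columns should sum to 1 for every observation,
# which duplicates the intercept column and explains the drop in rank.
dm = np.asarray(model.components["mu"].design.common.design_matrix)
np.allclose(dm[:, 1:].sum(axis=1), 1.0)  # expected to be True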
This would have been a huge problem with a maximum likelihood estimation approach: we would have obtained an error instead of parameter estimates. However, since we are doing Bayesian modeling, our priors regularized the parameter estimates and everything seemed to work pretty well.
Nevertheless, we can still do better. Why would we want to span the intercept twice? Let’s create and fit the model again, this time without spanning the intercept in the spline basis.
# Note we use the same priors
= bmb.Model("doy ~ bs(year, knots=iknots)", data, priors=priors)
model_new = model_new.fit(random_seed=SEED, idata_kwargs={"log_likelihood": True}) idata_new
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [sigma, Intercept, bs(year, knots=iknots)]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 9 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
And let’s have a look at the summary:
az.summary(idata_new)
 | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
---|---|---|---|---|---|---|---|---|---|
Intercept | 102.358 | 1.889 | 98.707 | 105.711 | 0.084 | 0.060 | 505.0 | 779.0 | 1.0 |
bs(year, knots=iknots)[0] | -0.875 | 3.832 | -7.709 | 6.539 | 0.135 | 0.096 | 807.0 | 1356.0 | 1.0 |
bs(year, knots=iknots)[1] | 0.430 | 2.954 | -4.934 | 6.052 | 0.083 | 0.065 | 1284.0 | 1460.0 | 1.0 |
bs(year, knots=iknots)[2] | 5.714 | 2.604 | 0.465 | 10.248 | 0.097 | 0.069 | 720.0 | 1266.0 | 1.0 |
bs(year, knots=iknots)[3] | 0.226 | 2.446 | -4.181 | 4.810 | 0.092 | 0.065 | 708.0 | 1168.0 | 1.0 |
bs(year, knots=iknots)[4] | 5.279 | 2.590 | 0.500 | 10.089 | 0.093 | 0.066 | 777.0 | 1304.0 | 1.0 |
bs(year, knots=iknots)[5] | -4.341 | 2.412 | -8.440 | 0.732 | 0.089 | 0.063 | 736.0 | 1127.0 | 1.0 |
bs(year, knots=iknots)[6] | 8.796 | 2.421 | 4.549 | 13.533 | 0.088 | 0.062 | 753.0 | 1246.0 | 1.0 |
bs(year, knots=iknots)[7] | 0.029 | 2.521 | -4.922 | 4.602 | 0.093 | 0.066 | 732.0 | 1119.0 | 1.0 |
bs(year, knots=iknots)[8] | 4.004 | 2.551 | -0.700 | 8.912 | 0.091 | 0.064 | 790.0 | 1119.0 | 1.0 |
bs(year, knots=iknots)[9] | 5.651 | 2.563 | 1.126 | 10.653 | 0.092 | 0.065 | 773.0 | 1350.0 | 1.0 |
bs(year, knots=iknots)[10] | 0.860 | 2.508 | -3.605 | 5.611 | 0.091 | 0.065 | 756.0 | 1174.0 | 1.0 |
bs(year, knots=iknots)[11] | 6.501 | 2.544 | 1.390 | 10.852 | 0.091 | 0.064 | 783.0 | 1127.0 | 1.0 |
bs(year, knots=iknots)[12] | 1.716 | 2.622 | -3.177 | 6.699 | 0.091 | 0.065 | 824.0 | 1266.0 | 1.0 |
bs(year, knots=iknots)[13] | 0.118 | 2.980 | -5.096 | 5.831 | 0.095 | 0.067 | 974.0 | 1329.0 | 1.0 |
bs(year, knots=iknots)[14] | -5.901 | 3.032 | -11.364 | -0.319 | 0.092 | 0.065 | 1087.0 | 1447.0 | 1.0 |
bs(year, knots=iknots)[15] | -6.734 | 2.977 | -12.309 | -1.337 | 0.101 | 0.071 | 883.0 | 1399.0 | 1.0 |
sigma | 5.941 | 0.152 | 5.657 | 6.219 | 0.003 | 0.002 | 2631.0 | 1420.0 | 1.0 |
There are a few things to remark here:
- There are 16 coefficients associated with the B-spline now because we’re not spanning the intercept.
- The ESS numbers have improved in all cases. Notice the sampler isn’t raising any warnings about low ESS.
- The r_hat values are still 1.
We can also compare the sampling times:
idata.posterior.sampling_time
9.64709210395813
idata_new.posterior.sampling_time
9.324542045593262
Sampling times are similar in this particular example. But in general, we expect the sampler to run faster when there aren’t structural dependencies in the design matrix.
And what about predictions?
plot_predictions(data, idata_new, model_new);
/tmp/ipykernel_46679/4286558085.py:8: FutureWarning: extract_dataset has been deprecated, please use extract
posterior_stacked = az.extract_dataset(idata)
And model comparison?
= {"Original": idata, "New": idata_new}
models_dict = az.compare(models_dict)
df_compare df_compare
 | rank | elpd_loo | p_loo | elpd_diff | weight | se | dse | warning | scale |
---|---|---|---|---|---|---|---|---|---|
New | 0 | -2657.850604 | 15.926137 | 0.000000 | 1.0 | 21.193558 | 0.000000 | False | log |
Original | 1 | -2658.586671 | 16.913997 | 0.736068 | 0.0 | 21.192815 | 0.549973 | False | log |
az.plot_compare(df_compare, insample_dev=False);
Finally, let’s check for influential points according to the Pareto k-hat values.
# Compute pointwise LOO
loo_1 = az.loo(idata, pointwise=True)
loo_2 = az.loo(idata_new, pointwise=True)
# plot kappa values
az.plot_khat(loo_1.pareto_k);
az.plot_khat(loo_2.pareto_k);
Final comments
Another option could have been to use stronger priors on the coefficients associated with the spline basis functions. For example, the PyMC example mentioned at the beginning uses \(\text{Normal}(0, 3)\) priors on them instead of \(\text{Normal}(0, 10)\).
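For reference, this is a sketch of how such narrower priors could be specified in Bambi (following the PyMC example; the model below is not fitted here):

# Sketch: narrower priors on the spline coefficients, as in the PyMC example (not fitted here)
priors_tight = {
    "Intercept": bmb.Prior("Normal", mu=100, sigma=10),
    "common": bmb.Prior("Normal", mu=0, sigma=3),
    "sigma": bmb.Prior("Exponential", lam=1),
}
# model_tight = bmb.Model("doy ~ bs(year, knots=iknots)", data, priors=priors_tight)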
%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Sun May 26 2024
Python implementation: CPython
Python version : 3.11.9
IPython version : 8.24.0
pandas : 2.2.2
bambi : 0.13.1.dev39+gb7d6a6cb
numpy : 1.26.4
arviz : 0.18.0
matplotlib: 3.8.4
Watermark: 2.4.3