Regression splines (Cherry blossom example)

This example shows how to specify and fit a spline regression in Bambi. It is based on an example from the PyMC docs.

import arviz as az
import bambi as bmb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
az.style.use("arviz-darkgrid")
SEED = 7355608

Load Cherry Blossom data

Richard McElreath popularized the Cherry Blossom dataset in the second edition of his excellent book Statistical Rethinking. The data record the day of the year on which the first bloom of Japanese cherry blossoms was observed, for years between 801 and 2015. In his book, Richard McElreath uses this dataset to introduce Basis Splines, or B-Splines for short.

Here we use Bambi to fit a linear model using B-Splines with the Cherry Blossom data. This dataset can be loaded with Bambi as follows:

data = bmb.load_data("cherry_blossoms")
data
year doy temp temp_upper temp_lower
0 801 NaN NaN NaN NaN
1 802 NaN NaN NaN NaN
2 803 NaN NaN NaN NaN
3 804 NaN NaN NaN NaN
4 805 NaN NaN NaN NaN
... ... ... ... ... ...
1210 2011 99.0 NaN NaN NaN
1211 2012 101.0 NaN NaN NaN
1212 2013 93.0 NaN NaN NaN
1213 2014 94.0 NaN NaN NaN
1214 2015 93.0 NaN NaN NaN

1215 rows × 5 columns

The variable we are interested in modeling is "doy", which stands for Day of Year. Notice that this variable contains several missing values, which we discard next.

data = data.dropna(subset=["doy"]).reset_index(drop=True)
data.shape
(827, 5)

Explore the data

Let’s get started by creating a scatterplot to explore the values of "doy" for each year in the dataset.

# We create a function because this plot is going to be used again later
def plot_scatter(data, figsize=(10, 6)):
    _, ax = plt.subplots(figsize=figsize)
    ax.scatter(data["year"], data["doy"], alpha=0.4, s=30)
    ax.set_title("Day of the first bloom per year")
    ax.set_xlabel("Year")
    ax.set_ylabel("Day of the first bloom")
    return ax
plot_scatter(data);

We can observe that the day of the first bloom ranges between approximately 85 and 125, which correspond to late March and early May respectively. On average, the first bloom occurs on the 105th day of the year, which is mid-April.
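To make the mapping from day of year to calendar date concrete, here is a minimal sketch using only the standard library. The helper name `doy_to_date` and the reference year 2015 are assumptions made for illustration; in a leap year every date after February shifts by one day.

```python
from datetime import date, timedelta


def doy_to_date(doy, year=2015):
    """Convert a 1-based day of year into a calendar date (hypothetical helper)."""
    return date(year, 1, 1) + timedelta(days=doy - 1)


for doy in (85, 105, 125):
    print(doy, doy_to_date(doy).strftime("%B %d"))
# 85 March 26
# 105 April 15
# 125 May 05
```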

Determine knots

The spline will have 15 knots, counting the two boundary knots at the first and last observed years. These knots are the boundaries of the basis functions and split the range of the "year" variable into 14 contiguous sections. The basis functions make up a piecewise polynomial whose pieces are constrained to join smoothly at the knots. We use the default degree for each piecewise polynomial, which is 3. The result is known as a cubic spline.

Because the knots are placed at quantiles and observations are not available for every year in the time window under study, the knots are distributed unevenly over the range of "year", in such a way that approximately the same proportion of observations falls within each section.
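The effect of quantile-based knots can be checked on simulated data. The sketch below does not use the cherry blossom data: the `years` array is a hypothetical stand-in in which recent years are more likely to be observed than early ones. Counting observations between consecutive knots shows nearly equal counts per section, while the section widths in years differ.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for the observed years: 827 distinct years between
# 801 and 2015, with recent years more likely to be present than early ones.
all_years = np.arange(801, 2016)
weights = np.linspace(1.0, 4.0, all_years.size)
years = np.sort(
    rng.choice(all_years, size=827, replace=False, p=weights / weights.sum())
)

num_knots = 15
knots = np.quantile(years, np.linspace(0, 1, num_knots))

# Number of observations in each section delimited by consecutive knots
counts, _ = np.histogram(years, bins=knots)
widths = np.diff(knots)

print(len(knots), len(counts))     # 15 knots delimit 14 sections
print(counts.min(), counts.max())  # counts per section are nearly equal
print(widths.min(), widths.max())  # widths in years are not
```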

num_knots = 15
knots = np.quantile(data["year"], np.linspace(0, 1, num_knots))
def plot_knots(knots, ax):
    for knot in knots:
        ax.axvline(knot, color="0.1", alpha=0.4)
    return ax
ax = plot_scatter(data)
plot_knots(knots, ax);

The previous chart makes it easy to see the knots, represented by the vertical lines, are spaced unevenly over the years.

The model

The B-spline model we are about to create is simply a linear regression model with synthetic predictor variables. These predictors are the basis functions that are derived from the original year predictor.
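To see what these synthetic predictors look like, the following sketch evaluates a cubic B-spline basis with the textbook Cox-de Boor recursion. It is an illustration under stated assumptions, not Bambi's implementation: `bspline_basis` is a hypothetical helper, and the evenly spaced `year` grid is made up. With 13 internal knots and degree 3 the basis has 13 + 3 + 1 = 17 columns, every row sums to one, and at most 4 basis functions are nonzero at any given year.

```python
import numpy as np


def bspline_basis(x, internal_knots, lower, upper, degree=3):
    """Evaluate a clamped B-spline basis with the Cox-de Boor recursion.

    Hypothetical helper for illustration; not Bambi's implementation.
    """
    x = np.asarray(x, dtype=float)
    # Repeat each boundary knot degree + 1 times
    t = np.concatenate(
        [np.repeat(lower, degree + 1), internal_knots, np.repeat(upper, degree + 1)]
    )

    # Degree 0: indicator of each half-open interval [t[i], t[i + 1])
    B = np.zeros((x.size, len(t) - 1))
    for i in range(len(t) - 1):
        B[:, i] = (t[i] <= x) & (x < t[i + 1])
    # Assign the upper boundary to the last interval of positive width
    B[x == upper, len(t) - degree - 2] = 1.0

    # Raise the degree one step at a time, treating 0/0 as 0
    for d in range(1, degree + 1):
        B_next = np.zeros((x.size, len(t) - 1 - d))
        for i in range(len(t) - 1 - d):
            left = t[i + d] - t[i]
            right = t[i + d + 1] - t[i + 1]
            if left > 0:
                B_next[:, i] += (x - t[i]) / left * B[:, i]
            if right > 0:
                B_next[:, i] += (t[i + d + 1] - x) / right * B[:, i + 1]
        B = B_next
    return B


# Made-up grid of years and 15 quantile knots (13 of them internal)
year = np.linspace(801, 2015, 200)
knots = np.quantile(year, np.linspace(0, 1, 15))
B = bspline_basis(year, knots[1:-1], year.min(), year.max())

print(B.shape)                            # (200, 17)
print(np.allclose(B.sum(axis=1), 1.0))    # rows sum to one
print((B > 0).sum(axis=1).max())          # nonzero basis functions per year
```

Each column of `B` plays the role of one predictor in the linear regression.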

In math notation, we use a \(\text{Normal}\) distribution for the conditional distribution of \(Y\) when \(X = x_i\), i.e. \(Y_i\), the day of the first bloom in a given year.

\[ Y_i \sim \text{Normal}(\mu_i, \sigma) \]

So far, this looks like a regular linear regression model. The next line is where the spline comes into play:

\[ \mu_i = \alpha + \sum_{k=1}^K{w_kB_{k, i}} \]

The line above says that, for each observation \(i\), the mean is influenced by all the basis functions (going from \(k=1\) to \(k=K\)), plus an intercept \(\alpha\). The \(w_k\) values in the summation are the regression coefficients of each of the basis functions, and the \(B_{k, i}\) are the values of the basis functions for observation \(i\).
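The sum over basis functions is a matrix-vector product. The sketch below uses made-up numbers: the matrix `B` is a hypothetical basis whose rows are non-negative and sum to one (as B-spline rows do), and `alpha` and `w` are arbitrary values, not estimates from the model.

```python
import numpy as np

rng = np.random.default_rng(1)

n, K = 5, 4
# Hypothetical basis values: each row is non-negative and sums to one
B = rng.dirichlet(np.ones(K), size=n)
alpha = 100.0
w = np.array([-3.0, 5.0, 1.0, -7.0])

# Explicit sum over basis functions for each observation i
mu_loop = np.array(
    [alpha + sum(w[k] * B[i, k] for k in range(K)) for i in range(n)]
)

# The same quantity as a matrix-vector product
mu = alpha + B @ w

print(np.allclose(mu, mu_loop))  # True
```

Because every row of `B` sums to one, each `mu[i]` is `alpha` plus a weighted average of the coefficients, so it always lies between `alpha + w.min()` and `alpha + w.max()`.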

Finally, we will be using the following priors

\[ \begin{aligned} \alpha & \sim \text{Normal}(100, 10) \\ w_k & \sim \text{Normal}(0, 10)\\ \sigma & \sim \text{Exponential}(1) \end{aligned} \]

where \(k\) indexes the basis functions, as in the sum above.

# We only pass the internal knots to the `bs()` function.
iknots = knots[1:-1]

# Define dictionary of priors
priors = {
    "Intercept": bmb.Prior("Normal", mu=100, sigma=10),
    "common": bmb.Prior("Normal", mu=0, sigma=10), 
    "sigma": bmb.Prior("Exponential", lam=1)
}

# Define model
# The intercept=True means the basis also spans the intercept, as originally done in the book example.
model = bmb.Model("doy ~ bs(year, knots=iknots, intercept=True)", data, priors=priors)
model
       Formula: doy ~ bs(year, knots=iknots, intercept=True)
        Family: gaussian
          Link: mu = identity
  Observations: 827
        Priors: 
    target = mu
        Common-level effects
            Intercept ~ Normal(mu: 100.0, sigma: 10.0)
            bs(year, knots=iknots, intercept=True) ~ Normal(mu: 0.0, sigma: 10.0)
        
        Auxiliary parameters
            sigma ~ Exponential(lam: 1.0)

Let’s create a function to plot each of the basis functions in the model.

def plot_spline_basis(basis, year, figsize=(10, 6)):
    df = (
        pd.DataFrame(basis)
        .assign(year=year)
        .melt("year", var_name="basis_idx", value_name="value")
    )

    _, ax = plt.subplots(figsize=figsize)

    for idx in df.basis_idx.unique():
        d = df[df.basis_idx == idx]
        ax.plot(d["year"], d["value"])
    
    return ax

Below, we create a chart to visualize the b-spline basis. The overlap between the functions means that, at any given point in time, the regression function is influenced by more than one basis function. For example, if we look at the year 1200, we can see the regression line is going to be influenced mostly by the violet and brown functions, and to a lesser extent by the green and cyan ones. In summary, this is what enables us to capture local patterns in a smooth fashion.

B = model.response_component.design.common["bs(year, knots=iknots, intercept=True)"]
ax = plot_spline_basis(B, data["year"].values)
plot_knots(knots, ax);

Fit model

Now we fit the model. In Bambi, it is as easy as calling the .fit() method on the Model instance.

# The seed is to make results reproducible
idata = model.fit(random_seed=SEED, idata_kwargs={"log_likelihood": True})
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [doy_sigma, Intercept, bs(year, knots=iknots, intercept=True)]
100.00% [4000/4000 00:32<00:00 Sampling 2 chains, 0 divergences]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 33 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics

Analysis of the results

It is always good to use az.summary() to check parameter estimates as well as effective sample sizes and R-hat values. Here, the main goal is not to interpret the coefficients of the basis spline but to examine the ess and r_hat diagnostics. First, the effective sample sizes are not impressively high. Most bulk values are between roughly 350 and 750, which is low compared to the 2000 draws obtained. The clearest exception is the residual standard deviation sigma. Second, r_hat is not exactly 1 for every parameter (it is 1.01 for the Intercept), indicating there may be some issues with the mixing of the chains.

az.summary(idata)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
Intercept 103.387 2.444 98.582 107.719 0.131 0.093 348.0 540.0 1.01
bs(year, knots=iknots, intercept=True)[0] -3.074 3.819 -10.477 3.705 0.127 0.090 908.0 1319.0 1.00
bs(year, knots=iknots, intercept=True)[1] -0.841 3.949 -8.290 6.242 0.146 0.103 739.0 1089.0 1.00
bs(year, knots=iknots, intercept=True)[2] -1.167 3.662 -8.245 5.517 0.140 0.099 685.0 935.0 1.00
bs(year, knots=iknots, intercept=True)[3] 4.810 2.987 -0.362 10.721 0.135 0.096 487.0 915.0 1.00
bs(year, knots=iknots, intercept=True)[4] -0.881 2.970 -6.245 4.759 0.137 0.097 472.0 951.0 1.00
bs(year, knots=iknots, intercept=True)[5] 4.277 2.963 -0.901 9.904 0.134 0.095 488.0 1104.0 1.00
bs(year, knots=iknots, intercept=True)[6] -5.350 2.883 -11.223 -0.312 0.137 0.097 439.0 870.0 1.00
bs(year, knots=iknots, intercept=True)[7] 7.786 2.813 2.161 13.013 0.129 0.091 477.0 842.0 1.00
bs(year, knots=iknots, intercept=True)[8] -1.017 2.977 -6.426 4.689 0.141 0.100 445.0 697.0 1.00
bs(year, knots=iknots, intercept=True)[9] 2.927 2.958 -2.100 9.282 0.136 0.096 474.0 809.0 1.00
bs(year, knots=iknots, intercept=True)[10] 4.693 2.990 -0.911 10.137 0.137 0.097 477.0 837.0 1.00
bs(year, knots=iknots, intercept=True)[11] -0.246 2.943 -5.760 5.126 0.133 0.094 490.0 908.0 1.00
bs(year, knots=iknots, intercept=True)[12] 5.548 2.984 0.328 11.413 0.140 0.099 451.0 837.0 1.00
bs(year, knots=iknots, intercept=True)[13] 0.653 3.115 -4.897 6.839 0.132 0.094 557.0 933.0 1.00
bs(year, knots=iknots, intercept=True)[14] -0.778 3.345 -7.165 5.314 0.142 0.101 551.0 981.0 1.00
bs(year, knots=iknots, intercept=True)[15] -7.039 3.527 -13.975 -0.638 0.137 0.097 667.0 1021.0 1.00
bs(year, knots=iknots, intercept=True)[16] -7.711 3.293 -14.579 -2.133 0.135 0.095 595.0 1090.0 1.00
doy_sigma 5.944 0.143 5.671 6.198 0.003 0.002 3031.0 1497.0 1.00
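To build intuition for the r_hat column, here is a minimal sketch of a basic split R-hat computed with NumPy. It is a simplified version for illustration: ArviZ reports a rank-normalized variant by default, so its numbers will not match this function exactly. The function name `split_rhat` and the simulated chains are assumptions.

```python
import numpy as np


def split_rhat(draws):
    """Basic split R-hat for an array of shape (n_chains, n_draws)."""
    n_chains, n_draws = draws.shape
    half = n_draws // 2
    # Split every chain in two halves so within-chain trends are detected too
    halves = np.concatenate([draws[:, :half], draws[:, half : 2 * half]], axis=0)
    n = halves.shape[1]
    within = halves.var(axis=1, ddof=1).mean()
    between = n * halves.mean(axis=1).var(ddof=1)
    var_plus = (n - 1) / n * within + between / n
    return np.sqrt(var_plus / within)


rng = np.random.default_rng(2)

# Two simulated chains that explore the same distribution
good = rng.normal(size=(2, 1000))
# Two simulated chains whose means differ by one standard deviation
bad = good + np.array([[0.0], [1.0]])

print(split_rhat(good))  # close to 1
print(split_rhat(bad))   # clearly above 1
```

Values close to 1 mean the chains agree with each other; values noticeably above 1 indicate they have not mixed.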

We can also use az.plot_trace() to visualize the marginal posteriors and the sampling paths. These traces show a stationary, random pattern. If the paths were not stationary and random, we would be concerned about the convergence of the chains.

az.plot_trace(idata);

Now we can visualize the fitted basis functions. In addition, we include a thicker black line that represents the dot product between \(B\) and \(w\). This is the contribution of the b-spline to the linear predictor in the model.

posterior_stacked = az.extract(idata)
wp = posterior_stacked["bs(year, knots=iknots, intercept=True)"].mean("sample").values

ax = plot_spline_basis(B * wp.T, data["year"].values)
ax.plot(data.year.values, np.dot(B, wp.T), color="black", lw=3)
plot_knots(knots, ax);

Plot predictions and credible bands

Let’s create a function to plot the predicted mean value as well as credible bands for it.

def plot_predictions(data, idata, model):
    # Create a test dataset with observations spanning the whole range of year
    new_data = pd.DataFrame({"year": np.linspace(data.year.min(), data.year.max(), num=500)})
    
    # Predict the day of first blossom
    model.predict(idata, data=new_data)

    posterior_stacked = az.extract(idata)
    # Extract these predictions
    y_hat = posterior_stacked["doy_mean"]

    # Compute the mean of the predictions, plotted as a single line.
    y_hat_mean = y_hat.mean("sample")

    # Compute 94% equal-tailed credible intervals for the predictions, plotted as bands
    hdi_data = np.quantile(y_hat, [0.03, 0.97], axis=1)

    # Plot observed data
    ax = plot_scatter(data)
    
    # Plot predicted line
    ax.plot(new_data["year"], y_hat_mean, color="firebrick")
    
    # Plot credibility bands
    ax.fill_between(new_data["year"], hdi_data[0], hdi_data[1], alpha=0.4, color="firebrick")
    
    # Add knots
    plot_knots(knots, ax)
    
    return ax
plot_predictions(data, idata, model);
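The band computation inside the function can be tried in isolation. The sketch below uses simulated draws rather than the fitted model: `draws` is a hypothetical array with one row per grid point and one column per posterior draw. Taking the 3% and 97% quantiles across draws gives an equal-tailed 94% interval, which can differ from a highest density interval when the posterior is skewed.

```python
import numpy as np

rng = np.random.default_rng(3)

# Hypothetical posterior draws of the mean: 500 grid points x 2000 draws
n_points, n_draws = 500, 2000
center = 100 + 5 * np.sin(np.linspace(0, 6, n_points))
draws = center[:, None] + rng.normal(scale=1.5, size=(n_points, n_draws))

# Posterior mean and equal-tailed 94% band at every grid point
mean = draws.mean(axis=1)
lower, upper = np.quantile(draws, [0.03, 0.97], axis=1)

# Fraction of draws that fall inside the band
inside = ((draws >= lower[:, None]) & (draws <= upper[:, None])).mean()

print(lower.shape, upper.shape)  # (500,) (500,)
print(round(inside, 2))          # 0.94
```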

Advanced: Watch out for the underlying design matrix

We can write linear regression models in matrix form as

\[ \mathbf{y} = \mathbf{X}\boldsymbol{\beta} \]

where \(\mathbf{y}\) is the response column vector of shape \((n, 1)\), \(\mathbf{X}\) is the design matrix of shape \((n, p)\) that contains the values of the predictors for all the observations, and \(\boldsymbol{\beta}\) is the column vector of regression coefficients of shape \((p, 1)\).

Because it’s not something that you’re supposed to consult regularly, Bambi does not expose the design matrix. However, with some knowledge of the internals, it is possible to access it:

np.round(model.response_component.design.common.design_matrix, 3)
array([[1.   , 1.   , 0.   , ..., 0.   , 0.   , 0.   ],
       [1.   , 0.96 , 0.039, ..., 0.   , 0.   , 0.   ],
       [1.   , 0.767, 0.221, ..., 0.   , 0.   , 0.   ],
       ...,
       [1.   , 0.   , 0.   , ..., 0.002, 0.097, 0.902],
       [1.   , 0.   , 0.   , ..., 0.   , 0.05 , 0.95 ],
       [1.   , 0.   , 0.   , ..., 0.   , 0.   , 1.   ]])

Let’s have a look at its shape:

model.response_component.design.common.design_matrix.shape
(827, 18)

827 is the number of years we have data for, and 18 is the number of predictors/coefficients in the model. The first column of ones is due to the Intercept term. The remaining 17 columns belong to the spline basis: sixteen columns associated with the basis functions, plus one extra column because we used intercept=True when calling the function bs() in the model formula.

Now we could compute the rank of the design matrix to check whether all the columns are linearly independent.

np.linalg.matrix_rank(model.response_component.design.common.design_matrix)
17

Since \(\text{rank}(\mathbf{X})\) is smaller than the number of columns, we conclude the columns in \(\mathbf{X}\) are not linearly independent.

If we have a second look at our code, we are going to figure out we’re spanning the intercept twice. The first time with the intercept term itself, and the second time in the spline basis.

This would have been a huge problem in a maximum likelihood estimation approach: we would have obtained an error instead of parameter estimates. However, since we are doing Bayesian modeling, our priors regularized the parameter estimates and everything seemed to work pretty well.
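The rank deficiency and the regularizing effect of the priors can be reproduced on a small problem. The sketch below makes several simplifying assumptions: it uses a linear (degree 1) basis of "hat" functions instead of the cubic basis, simulated data, and a known residual standard deviation `sigma`. Under those assumptions the posterior mode with independent Normal priors of standard deviation `tau` solves a ridge-like linear system.

```python
import numpy as np

rng = np.random.default_rng(4)

x = np.linspace(0.0, 1.0, 50)
knots = np.linspace(0.0, 1.0, 6)

# Linear B-spline basis made of "hat" functions, a simplified stand-in for
# the cubic basis. Its columns sum to one in every row.
hats = np.eye(knots.size)
B = np.column_stack([np.interp(x, knots, hats[j]) for j in range(knots.size)])

# Intercept column plus a basis that already spans the intercept
X = np.column_stack([np.ones_like(x), B])
n_cols = X.shape[1]
print(n_cols, np.linalg.matrix_rank(X))  # 7 6

# Simulated response
y = 100 + 5 * np.sin(6 * x) + rng.normal(scale=1.0, size=x.size)

# Least squares: X'X is singular, so there is no unique solution
gram = X.T @ X
print(np.linalg.matrix_rank(gram))  # 6

# Normal priors with standard deviation tau (and sigma assumed known) add
# sigma^2 / tau^2 to the diagonal, which makes the system solvable
sigma, tau = 1.0, 10.0
lam = (sigma / tau) ** 2
prior_mean = np.zeros(n_cols)
prior_mean[0] = 100.0  # prior mean of the intercept, as in the model above
regularized = gram + lam * np.eye(n_cols)
beta = np.linalg.solve(regularized, X.T @ y + lam * prior_mean)

print(np.linalg.matrix_rank(regularized))  # 7
```

The fitted values `X @ beta` are well defined even though the individual coefficients are only pinned down by the prior along the redundant direction.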

Nevertheless, we can still do better. Why would we want to span the intercept twice? Let’s create and fit the model again, this time without spanning the intercept in the spline basis.

# Note we use the same priors
model_new = bmb.Model("doy ~ bs(year, knots=iknots)", data, priors=priors)
idata_new = model_new.fit(random_seed=SEED, idata_kwargs={"log_likelihood": True})
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [doy_sigma, Intercept, bs(year, knots=iknots)]
100.00% [4000/4000 00:31<00:00 Sampling 2 chains, 0 divergences]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 32 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics

And let’s have a look at the summary

az.summary(idata_new)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
Intercept 102.367 1.992 98.899 106.358 0.105 0.074 361.0 581.0 1.01
bs(year, knots=iknots)[0] -0.849 3.999 -8.142 6.704 0.164 0.116 591.0 930.0 1.00
bs(year, knots=iknots)[1] 0.394 3.012 -5.253 5.983 0.090 0.063 1132.0 1249.0 1.00
bs(year, knots=iknots)[2] 5.707 2.712 0.074 10.305 0.120 0.085 510.0 1017.0 1.00
bs(year, knots=iknots)[3] 0.216 2.467 -4.358 4.849 0.103 0.073 571.0 1320.0 1.00
bs(year, knots=iknots)[4] 5.237 2.711 0.104 10.568 0.118 0.084 526.0 789.0 1.00
bs(year, knots=iknots)[5] -4.332 2.428 -8.909 0.043 0.105 0.074 535.0 890.0 1.01
bs(year, knots=iknots)[6] 8.788 2.546 3.669 13.310 0.112 0.079 518.0 854.0 1.01
bs(year, knots=iknots)[7] 0.008 2.573 -5.056 4.474 0.112 0.079 525.0 916.0 1.00
bs(year, knots=iknots)[8] 3.980 2.745 -0.716 9.394 0.112 0.079 597.0 927.0 1.00
bs(year, knots=iknots)[9] 5.658 2.559 0.917 10.350 0.109 0.077 552.0 850.0 1.00
bs(year, knots=iknots)[10] 0.801 2.655 -4.092 5.842 0.112 0.079 565.0 956.0 1.00
bs(year, knots=iknots)[11] 6.534 2.578 1.952 11.575 0.112 0.079 531.0 845.0 1.01
bs(year, knots=iknots)[12] 1.703 2.772 -3.154 7.363 0.114 0.081 591.0 1126.0 1.00
bs(year, knots=iknots)[13] 0.190 3.076 -5.277 6.077 0.115 0.081 722.0 1258.0 1.00
bs(year, knots=iknots)[14] -6.026 3.162 -11.645 0.206 0.122 0.086 672.0 1164.0 1.00
bs(year, knots=iknots)[15] -6.715 3.005 -12.485 -1.229 0.118 0.084 641.0 1306.0 1.00
doy_sigma 5.949 0.146 5.674 6.221 0.003 0.002 2287.0 1466.0 1.00

There are a few things to remark here:

  • There are 16 coefficients associated with the B-spline now because we’re not spanning the intercept.
  • The bulk ESS values have improved for most of the coefficients. Notice the sampler isn’t raising any warning about low ESS.
  • r_hat values are still close to 1 (at most 1.01).

We can also compare the sampling times:

idata.posterior.sampling_time
32.5815589427948
idata_new.posterior.sampling_time
31.589828729629517

Sampling times are similar in this particular example. But in general, we expect the sampler to run faster when there aren’t structural dependencies in the design matrix.

And what about predictions?

plot_predictions(data, idata_new, model_new);

And model comparison?

models_dict = {"Original": idata, "New": idata_new}
df_compare = az.compare(models_dict)
df_compare
rank elpd_loo p_loo elpd_diff weight se dse warning scale
New 0 -2657.859115 15.945629 0.000000 1.000000e+00 21.134973 0.000000 False log
Original 1 -2658.359085 16.652034 0.499969 3.330669e-16 21.173433 0.561943 False log
az.plot_compare(df_compare, insample_dev=False);
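To read the elpd_diff and dse columns, it helps to see how such quantities are commonly derived from pointwise values. The sketch below uses made-up pointwise elpd values, not the ones from the models above, and follows the usual formula for the standard error of a difference; ArviZ's exact computation may differ in details.

```python
import numpy as np

rng = np.random.default_rng(5)

# Hypothetical pointwise elpd values for two models on the same 827 observations
n = 827
elpd_a = rng.normal(loc=-3.2, scale=0.8, size=n)
elpd_b = elpd_a + rng.normal(loc=0.0, scale=0.02, size=n)

# Difference in total elpd between the two models
diff = elpd_a - elpd_b
elpd_diff = diff.sum()

# Standard error of the difference, from the pointwise differences
dse = np.sqrt(n * diff.var())

# Standard error of a single model's elpd, for comparison
se_a = np.sqrt(n * elpd_a.var())

print(elpd_diff, dse, se_a)
```

Because both models are evaluated on the same observations, the pointwise differences vary much less than the pointwise values themselves, which is why dse is far smaller than se. In the table above, elpd_diff is about 0.50 with a dse of about 0.56, so the two models are practically indistinguishable in predictive terms.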

Finally let’s check influential points according to the k-hat value

# Compute pointwise LOO
loo_1 = az.loo(idata, pointwise=True)
loo_2 = az.loo(idata_new, pointwise=True)
# Plot Pareto k-hat values
az.plot_khat(loo_1.pareto_k);

az.plot_khat(loo_2.pareto_k);

Final comments

Another option could have been to use stronger priors on the coefficients associated with the spline functions. For example, the example written in PyMC uses \(\text{Normal}(0, 3)\) priors on them instead of \(\text{Normal}(0, 10)\).

%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Wed Jun 28 2023

Python implementation: CPython
Python version       : 3.10.4
IPython version      : 8.5.0

pandas    : 2.0.2
bambi     : 0.12.0.dev0
arviz     : 0.14.0
numpy     : 1.25.0
matplotlib: 3.6.2

Watermark: 2.3.1