Splines Tutorial -- Add section about predicting on new data #771

Merged
merged 1 commit into from
Feb 17, 2025
1,536 changes: 1,191 additions & 345 deletions examples/howto/spline.ipynb

Large diffs are not rendered by default.

195 changes: 182 additions & 13 deletions examples/howto/spline.myst.md
@@ -5,9 +5,9 @@ jupytext:
     format_name: myst
     format_version: 0.13
 kernelspec:
-  display_name: Python 3 (ipykernel)
+  display_name: pymc-examples
   language: python
-  name: python3
+  name: pymc-examples
---

(spline)=
@@ -43,14 +43,15 @@ import numpy as np
import pandas as pd
import pymc as pm

-from patsy import dmatrix
+from patsy import build_design_matrices, dmatrix
```

```{code-cell} ipython3
%matplotlib inline
%config InlineBackend.figure_format = "retina"

-RANDOM_SEED = 8927
+seed = sum(map(ord, "splines"))
+rng = np.random.default_rng(seed)
az.style.use("arviz-darkgrid")
```
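
The seeding change above can be tried in isolation. The following is a minimal sketch, not part of the notebook: it checks that summing the character codes of `"splines"` yields a fixed integer, and that two NumPy generators built from that integer produce identical draws. The variable names `rng_a`, `rng_b`, and `same` are made up for the illustration.

```python
import numpy as np

# Derive a deterministic integer seed from a string, as in the diff above.
seed = sum(map(ord, "splines"))

# Two independent generators created from the same seed (hypothetical names).
rng_a = np.random.default_rng(seed)
rng_b = np.random.default_rng(seed)

# Both generators should yield exactly the same sequence of draws.
draws_a = rng_a.normal(size=3)
draws_b = rng_b.normal(size=3)
same = np.array_equal(draws_a, draws_b)
```

A string-derived seed is just a readable way to pick a fixed integer; any fixed integer would give the same reproducibility.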

@@ -84,7 +85,12 @@ If we visualize the data, it is clear that there a lot of annual variation, but

```{code-cell} ipython3
 blossom_data.plot.scatter(
-    "year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Days in bloom"
+    "year",
+    "doy",
+    color="cornflowerblue",
+    s=10,
+    title="Cherry Blossom Data",
+    ylabel="Days in bloom",
 );
```

@@ -106,18 +112,23 @@ The spline will have 15 *knots*, splitting the year into 16 sections (including

```{code-cell} ipython3
num_knots = 15
-knot_list = np.quantile(blossom_data.year, np.linspace(0, 1, num_knots))
+knot_list = np.percentile(blossom_data.year, np.linspace(0, 100, num_knots + 2))[1:-1]
knot_list
```
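
To see what the knot change does, here is a small sketch comparing the two expressions. It uses a hypothetical, evenly spaced `years` array as a stand-in for `blossom_data.year`, so the knot values will differ from the notebook's; only the counts and the handling of the endpoints are the point.

```python
import numpy as np

# Hypothetical stand-in for blossom_data.year (the real data is not evenly spaced).
years = np.arange(800, 2016)
num_knots = 15

# Old expression: 15 quantiles that include the minimum and maximum year.
# The old dmatrix call then dropped the two endpoints with knot_list[1:-1].
old_knots = np.quantile(years, np.linspace(0, 1, num_knots))
old_interior = old_knots[1:-1]

# New expression: compute num_knots + 2 percentiles and drop the endpoints
# right away, leaving exactly num_knots interior knots.
new_knots = np.percentile(years, np.linspace(0, 100, num_knots + 2))[1:-1]

n_old_interior = len(old_interior)  # 13
n_new_interior = len(new_knots)     # 15
```

Under these assumptions the old code effectively passed 13 interior knots to the spline, while the new code passes all 15, which matches the text's description of 15 knots.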

Below is a plot of the locations of the knots over the data.

```{code-cell} ipython3
 blossom_data.plot.scatter(
-    "year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Day of Year"
+    "year",
+    "doy",
+    color="cornflowerblue",
+    s=10,
+    title="Cherry Blossom Data",
+    ylabel="Day of Year",
 )
 for knot in knot_list:
-    plt.gca().axvline(knot, color="grey", alpha=0.4);
+    plt.gca().axvline(knot, color="grey", alpha=0.4)
```

We can use `patsy` to create the matrix $B$ that will be the b-spline basis for the regression.
@@ -128,7 +139,7 @@ The degree is set to 3 to create a cubic b-spline.

 B = dmatrix(
     "bs(year, knots=knots, degree=3, include_intercept=True) - 1",
-    {"year": blossom_data.year.values, "knots": knot_list[1:-1]},
+    {"year": blossom_data.year.values, "knots": knot_list},
 )
B
```
@@ -160,9 +171,14 @@ COORDS = {"splines": np.arange(B.shape[1])}
 with pm.Model(coords=COORDS) as spline_model:
     a = pm.Normal("a", 100, 5)
     w = pm.Normal("w", mu=0, sigma=3, size=B.shape[1], dims="splines")
-    mu = pm.Deterministic("mu", a + pm.math.dot(np.asarray(B, order="F"), w.T))
+
+    mu = pm.Deterministic(
+        "mu",
+        a + pm.math.dot(np.asarray(B, order="F"), w.T),
+    )
     sigma = pm.Exponential("sigma", 1)
-    D = pm.Normal("D", mu=mu, sigma=sigma, observed=blossom_data.doy, dims="obs")
+
+    D = pm.Normal("D", mu=mu, sigma=sigma, observed=blossom_data.doy)
```
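
The linear predictor in this model, `a + dot(B, w)`, is an intercept plus a weighted sum of the basis columns. The sketch below reproduces that arithmetic in plain NumPy with made-up sizes and random stand-ins for `B`, `w`, and `a`; it is an illustration of the algebra, not the model itself.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: 5 observations, 4 basis functions.
n_obs, n_splines = 5, 4
B = rng.random((n_obs, n_splines))      # stand-in for the patsy basis matrix
w = rng.normal(0, 3, size=n_splines)    # one weight per basis function
a = 100.0                               # stand-in for the intercept

# Matrix form, as used in the model.
mu = a + B @ w

# Equivalent explicit form: weighted sum of the basis columns.
mu_manual = a + sum(w[k] * B[:, k] for k in range(n_splines))

# For a 1-D NumPy array, transposing is a no-op.
transpose_is_noop = np.array_equal(w.T, w)
```

In NumPy, `w.T` on a one-dimensional array returns the same values, which is consistent with the later version of the model in this diff writing `pm.math.dot(splines_basis, w)` without the transpose.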

```{code-cell} ipython3
@@ -172,7 +188,15 @@ pm.model_to_graphviz(spline_model)
```{code-cell} ipython3
 with spline_model:
     idata = pm.sample_prior_predictive()
-    idata.extend(pm.sample(draws=1000, tune=1000, random_seed=RANDOM_SEED, chains=4))
+    idata.extend(
+        pm.sample(
+            nuts_sampler="nutpie",
+            draws=1000,
+            tune=1000,
+            random_seed=rng,
+            chains=4,
+        )
+    )
     pm.sample_posterior_predictive(idata, extend_inferencedata=True)
```

@@ -230,7 +254,7 @@ spline_df_merged.plot("year", "value", c="black", lw=2, ax=plt.gca())
plt.legend(title="Spline Index", loc="lower center", fontsize=8, ncol=6)

 for knot in knot_list:
-    plt.gca().axvline(knot, color="grey", alpha=0.4);
+    plt.gca().axvline(knot, color="grey", alpha=0.4)
```

### Model predictions
@@ -267,6 +291,150 @@ plt.fill_between(
);
```

## Predicting on new data

Now imagine we receive a new data set that covers the same range of years as the original one, and we want predictions for it from our fitted model. We can do this with the classic PyMC workflow of `Data` containers and the `set_data` method.

Before we get there, though, note that we didn't say the new data set contains *new* years, i.e., out-of-sample years. That's on purpose: the spline basis is only defined over the range of years used to fit the model, so splines can't extrapolate beyond it, which limits their use for time series forecasting. Within the range already seen, though, prediction is no problem.

With that clarification out of the way, let's redefine our model, this time adding `Data` containers.

```{code-cell} ipython3
COORDS = {"obs": blossom_data.index}
```

```{code-cell} ipython3
with pm.Model(coords=COORDS) as spline_model:
    year_data = pm.Data("year", blossom_data.year)
    doy = pm.Data("doy", blossom_data.doy)

    # intercept
    a = pm.Normal("a", 100, 5)

    # Create spline bases & coefficients
    ## Store knots & design matrix for prediction
    spline_model.knots = np.percentile(year_data.eval(), np.linspace(0, 100, num_knots + 2))[1:-1]
    spline_model.dm = dmatrix(
        "bs(x, knots=spline_model.knots, degree=3, include_intercept=False) - 1",
        {"x": year_data.eval()},
    )
    spline_model.add_coords({"spline": np.arange(spline_model.dm.shape[1])})
    splines_basis = pm.Data("splines_basis", np.asarray(spline_model.dm), dims=("obs", "spline"))
    w = pm.Normal("w", mu=0, sigma=3, dims="spline")

    mu = pm.Deterministic(
        "mu",
        a + pm.math.dot(splines_basis, w),
    )
    sigma = pm.Exponential("sigma", 1)

    D = pm.Normal("D", mu=mu, sigma=sigma, observed=doy)
```

```{code-cell} ipython3
pm.model_to_graphviz(spline_model)
```

```{code-cell} ipython3
with spline_model:
    idata = pm.sample(
        nuts_sampler="nutpie",
        random_seed=rng,
    )
    idata.extend(pm.sample_posterior_predictive(idata, random_seed=rng))
```

Now we can swap out the data and update the design matrix with the new data:

```{code-cell} ipython3
new_blossom_data = (
    blossom_data.sample(50, random_state=rng).sort_values("year").reset_index(drop=True)
)

# update design matrix with new data
year_data_new = new_blossom_data.year.to_numpy()
dm_new = build_design_matrices([spline_model.dm.design_info], {"x": year_data_new})[0]
```
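
The reuse of `design_info` above can be checked on its own, without PyMC. The sketch below is a hedged illustration under several assumptions: `years` is a made-up stand-in for the blossom years, `within_training_range` is a hypothetical helper, and the knots are passed through the data dictionary (rather than read from the model object as in the notebook) so the example does not depend on where it is executed.

```python
import numpy as np
from patsy import build_design_matrices, dmatrix

# Hypothetical training years and interior knots (stand-ins for the notebook's data).
years = np.arange(800, 2016, dtype=float)
knots = np.percentile(years, np.linspace(0, 100, 15 + 2))[1:-1]

# Build the basis once on the training data; patsy stores the spline setup
# (knots and boundaries) in dm.design_info.
dm = dmatrix(
    "bs(x, knots=knots, degree=3, include_intercept=False) - 1",
    {"x": years, "knots": knots},
)

# A "new" data set drawn from the same range of years.
idx = np.sort(np.random.default_rng(0).choice(len(years), size=50, replace=False))
new_years = years[idx]


def within_training_range(new_x, train_x):
    """Hypothetical guard: True if every new value lies inside the training range."""
    return bool(new_x.min() >= train_x.min() and new_x.max() <= train_x.max())


in_range = within_training_range(new_years, years)

# Re-evaluate the stored spline definition on the new years.
dm_new = build_design_matrices(
    [dm.design_info], {"x": new_years, "knots": knots}
)[0]
```

Because the new years are a subset of the training years, each row of `dm_new` should match the corresponding row of the original basis, which is what makes the fitted weights `w` reusable. A range check like the one sketched here is a cheap way to catch out-of-sample years before rebuilding the basis.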

Use `set_data` to update the data in the model:

```{code-cell} ipython3
with spline_model:
    pm.set_data(
        new_data={
            "year": year_data_new,
            "doy": new_blossom_data.doy.to_numpy(),
            "splines_basis": np.asarray(dm_new),
        },
        coords={
            "obs": new_blossom_data.index,
        },
    )
```

And all that's left is to sample from the posterior predictive distribution:

```{code-cell} ipython3
with spline_model:
    preds = pm.sample_posterior_predictive(idata, var_names=["mu"])
```
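
The sampled `mu` has one value per chain, draw, and observation, and the plotting step reduces it to a mean curve and an uncertainty band. The following sketch shows that reduction in plain NumPy on made-up draws. Note the assumptions: the array `mu_draws` is simulated, and the band is an equal-tailed 94% percentile interval used as a simple stand-in for the highest-density interval that `az.plot_hdi` draws.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical draws of mu with shape (chain, draw, obs).
mu_draws = rng.normal(loc=100.0, scale=2.0, size=(4, 1000, 50))

# Mean over chains and draws: one value per observation,
# analogous to .mean(("chain", "draw")) on the xarray object.
mu_mean = mu_draws.mean(axis=(0, 1))

# Pool chains and draws, then take an equal-tailed interval per observation.
# This is an illustrative substitute for an HDI, not the same computation.
flat = mu_draws.reshape(-1, mu_draws.shape[-1])
lower, upper = np.percentile(flat, [3, 97], axis=0)
```

For roughly symmetric posteriors the two kinds of interval are usually close, but they can differ noticeably for skewed ones.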

Plot the predictions, to check if everything went well:

```{code-cell} ipython3
_, axes = plt.subplots(1, 2, figsize=(16, 5), sharex=True, sharey=True)

blossom_data.plot.scatter(
    "year",
    "doy",
    color="cornflowerblue",
    s=10,
    title="Posterior predictions",
    ylabel="Days in bloom",
    ax=axes[0],
)
axes[0].vlines(
    spline_model.knots,
    blossom_data.doy.min(),
    blossom_data.doy.max(),
    color="grey",
    alpha=0.4,
)
axes[0].plot(
    blossom_data.year,
    idata.posterior["mu"].mean(("chain", "draw")),
    color="firebrick",
)
az.plot_hdi(blossom_data.year, idata.posterior["mu"], color="firebrick", ax=axes[0])

new_blossom_data.plot.scatter(
    "year",
    "doy",
    color="cornflowerblue",
    s=10,
    title="Predictions on new data",
    ylabel="Days in bloom",
    ax=axes[1],
)
axes[1].vlines(
    spline_model.knots,
    blossom_data.doy.min(),
    blossom_data.doy.max(),
    color="grey",
    alpha=0.4,
)
axes[1].plot(
    new_blossom_data.year,
    preds.posterior_predictive.mu.mean(("chain", "draw")),
    color="firebrick",
)
az.plot_hdi(new_blossom_data.year, preds.posterior_predictive.mu, color="firebrick", ax=axes[1]);
```

And... voilà! Granted, this example is not the most realistic one, but we trust you to adapt it to your wildest dreams ;)

+++

## References

:::{bibliography}
@@ -280,6 +448,7 @@ plt.fill_between(
- Created by Joshua Cook
- Updated by Tyler James Burch
- Updated by Chris Fonnesbeck
- Predictions on new data added by Alex Andorra

+++
