update bart example to use pymc-bart package #416

Merged
merged 2 commits into from
Sep 11, 2022
65 changes: 32 additions & 33 deletions examples/case_studies/BART_introduction.ipynb

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions examples/references.bib
@@ -380,6 +380,16 @@ @article{nuijten2015default
year = {2015},
publisher = {Springer}
}
@misc{quiroga2022bart,
title = {Bayesian additive regression trees for probabilistic programming},
author = {Quiroga, Miriana and Garay, Pablo G and Alonso, Juan M. and Loyola, Juan Martin and Martin, Osvaldo A},
publisher = {arXiv},
doi = {10.48550/ARXIV.2206.03619},
url = {https://arxiv.org/abs/2206.03619},
keywords = {Computation (stat.CO), FOS: Computer and information sciences, FOS: Computer and information sciences},
year = {2022},
copyright = {Creative Commons Attribution Share Alike 4.0 International}
}
@book{rasmussen2003gaussian,
title = {Gaussian Processes for Machine Learning},
author = {Rasmussen, Carl Edward and Williams, Christopher K. I.},
17 changes: 9 additions & 8 deletions myst_nbs/case_studies/BART_introduction.myst.md
@@ -6,7 +6,7 @@ jupytext:
format_version: 0.13
jupytext_version: 1.13.7
kernelspec:
-display_name: Python 3.9.7 ('base')
+display_name: Python 3 (ipykernel)
language: python
name: python3
---
@@ -27,7 +27,7 @@ import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
-import pymc_experimental as pmx
+import pymc_bart as pmb

print(f"Running on PyMC v{pm.__version__}")
```
@@ -86,7 +86,7 @@ In PyMC a BART variable can be defined very similar to other random variables. O

```{code-cell} ipython3
 with pm.Model() as model_coal:
-    μ_ = pmx.BART("μ_", X=x_data, Y=y_data, m=20)
+    μ_ = pmb.BART("μ_", X=x_data, Y=y_data, m=20)
     μ = pm.Deterministic("μ", np.abs(μ_))
     y_pred = pm.Poisson("y_pred", mu=μ, observed=y_data)
     idata_coal = pm.sample(random_seed=RANDOM_SEED)
@@ -113,7 +113,7 @@ In the previous plot the white line is the median over 4000 posterior draws, and
The following figure shows two samples from the posterior of $\mu$. We can see that these functions are not smooth. This is fine and is a direct consequence of using regression trees. Trees can be seen as a way to represent stepwise functions, and a sum of stepwise functions is just another stepwise function. Thus, when using BART we need to keep in mind that we are assuming a stepwise function is a good enough approximation for our problem. In practice this is often the case because we sum over many trees, usually values like 50, 100 or 200. Additionally, we often average over the posterior distribution. All this makes the steps "smoother", even though we never really obtain a smooth function as we would, for example, with Gaussian processes or splines. A nice theoretical result tells us that in the limit of $m \to \infty$ the BART prior converges to a [nowhere-differentiable](https://en.wikipedia.org/wiki/Weierstrass_function) Gaussian process.

```{code-cell} ipython3
-plt.step(x_data, np.exp(pmx.bart.predict(idata_coal, rng, x_data, size=2).T));
+plt.step(x_data, np.exp(pmb.predict(idata_coal, rng, x_data, size=2).squeeze().T));
```
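
The statement that a sum of stepwise functions is again a stepwise function can be checked numerically. The following is a minimal sketch using plain NumPy, not the tree sampler used by BART: `random_stump`, `sum_of_stumps` and `count_jumps` are hypothetical helpers introduced only for this illustration, and each "tree" is reduced to a single random split.

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(0, 1, 500)


def random_stump(x, rng):
    """A depth-1 'tree': one random split point and one constant value per side."""
    split = rng.uniform(x.min(), x.max())
    left, right = rng.normal(size=2)
    return np.where(x < split, left, right)


def sum_of_stumps(x, m, rng):
    """Sum m independent stumps; the result is still piecewise constant."""
    return sum(random_stump(x, rng) for _ in range(m))


def count_jumps(f):
    """Number of grid positions where the function value changes."""
    return np.count_nonzero(np.diff(f))


single = sum_of_stumps(x, 1, rng)  # one step at most
many = sum_of_stumps(x, 50, rng)   # at most 50 steps, never a smooth curve
```

Each stump contributes at most one jump, so the sum of `m` stumps has at most `m` jumps. Adding trees makes the steps finer, but under this construction the function stays piecewise constant.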

To gain further intuition, the next figures show 3 of the `m` trees. As we can see, these are definitely not very good approximators by themselves. Inspecting individual trees is generally not necessary; we show them here only to build intuition about BART.
@@ -143,7 +143,7 @@ Y = bikes["count"]
```{code-cell} ipython3
 with pm.Model() as model_bikes:
     σ = pm.HalfNormal("σ", Y.std())
-    μ = pmx.BART("μ", X, Y, m=50)
+    μ = pmb.BART("μ", X, Y, m=50)
     y = pm.Normal("y", μ, σ, observed=Y)
     idata_bikes = pm.sample(random_seed=RANDOM_SEED)
```
@@ -155,7 +155,7 @@ with pm.Model() as model_bikes:
To help us interpret the results of our model we are going to use a partial dependence plot. This is a type of plot that shows the marginal effect that one covariate has on the predicted variable. That is, what is the effect that a covariate $X_i$ has on $Y$ while we average over all the other covariates ($X_j, \forall j \neq i$). This type of plot is not exclusive to BART, but it is often used in the BART literature. PyMC provides a utility function to make this plot from the inference data.

```{code-cell} ipython3
-pmx.bart.plot_dependence(idata_bikes, X=X, Y=Y, grid=(2, 2), var_discrete=[3]);
+pmb.plot_dependence(idata_bikes, X=X, Y=Y, grid=(2, 2), var_discrete=[3]);
```
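
The averaging behind a partial dependence plot can be sketched by hand. The code below is an illustration only and is not how the plotting utility is implemented: `partial_dependence` and `toy_predict` are hypothetical names, and the toy predictor stands in for a fitted model. For each grid value, column `i` is overwritten for every row and the predictions are averaged over the remaining covariates.

```python
import numpy as np


def partial_dependence(predict, X, i, grid):
    """Average prediction as covariate i is fixed at each grid value."""
    X = np.asarray(X, dtype=float)
    pd_values = []
    for value in grid:
        X_mod = X.copy()
        X_mod[:, i] = value  # fix X_i, keep the observed values of all X_j
        pd_values.append(predict(X_mod).mean())
    return np.array(pd_values)


def toy_predict(X):
    """Hypothetical stand-in for a fitted model: linear in X_0, quadratic in X_1."""
    return 2.0 * X[:, 0] + X[:, 1] ** 2


rng = np.random.default_rng(0)
X_toy = rng.normal(size=(200, 2))
grid = np.linspace(-1, 1, 5)
pd_x0 = partial_dependence(toy_predict, X_toy, 0, grid)
```

For this toy predictor the partial dependence on the first covariate is a straight line with slope 2, shifted by the average of the squared second covariate, which is what the definition predicts.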

From this plot we can see the main effect of each covariate on the predicted value. This is very useful because we can recover complex relationships beyond monotonically increasing or decreasing effects. For example, for the `hour` covariate we can see two peaks around 8 and 17 hs and a minimum at midnight.
@@ -176,12 +176,13 @@ Additionally, we provide a novel method to assess the variable importance. You c

```{code-cell} ipython3
labels = ["hour", "temperature", "humidity", "workingday"]
-pmx.bart.utils.plot_variable_importance(idata_bikes, X.values, labels, samples=100);
+pmb.plot_variable_importance(idata_bikes, X.values, labels, samples=100);
```

## Authors
* Authored by Osvaldo Martin in Dec, 2021 ([pymc-examples#259](https://github.com/pymc-devs/pymc-examples/pull/259))
* Updated by Osvaldo Martin in May, 2022 ([pymc-examples#323](https://github.com/pymc-devs/pymc-examples/pull/323))
* Updated by Osvaldo Martin in Sep, 2022

+++

@@ -190,8 +191,8 @@ pmx.bart.utils.plot_variable_importance(idata_bikes, X.values, labels, samples=1
:::{bibliography}
:filter: docname in docnames

martin2018bayesian
martin2021bayesian
quiroga2022bart
:::

+++
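
The renames applied across the hunks of this diff can be summarised as a small text-rewriting sketch. The mapping is read directly off the changed lines; the `RENAMES` table and the `migrate` helper are hypothetical names used for illustration and are not part of either package. The helper assumes the `pmx` alias is used only for BART in the file being rewritten.

```python
# Old -> new names as they appear in this diff. Longer names come first so
# that a shorter prefix never rewrites part of a longer one.
RENAMES = [
    ("pmx.bart.utils.plot_variable_importance", "pmb.plot_variable_importance"),
    ("pmx.bart.plot_dependence", "pmb.plot_dependence"),
    ("pmx.bart.predict", "pmb.predict"),
    ("pmx.BART", "pmb.BART"),
    ("import pymc_experimental as pmx", "import pymc_bart as pmb"),
]


def migrate(source: str) -> str:
    """Apply the plain-text renames to notebook or script source."""
    for old, new in RENAMES:
        source = source.replace(old, new)
    return source


old_cell = 'μ = pmx.BART("μ", X, Y, m=50)'
new_cell = migrate(old_cell)
```

A plain rename does not cover every change: in this diff the output of `predict` also gains a `.squeeze()` call before the transpose, so call sites that use the returned array still need to be checked by hand.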