
Commit b05301c

update bart example to use pymc-bart package (#416)
* update bart example to use pymc-bart package
* update references
1 parent f603fe3 commit b05301c

File tree

3 files changed: +51 −41 lines changed

examples/case_studies/BART_introduction.ipynb

Lines changed: 32 additions & 33 deletions
Large diffs are not rendered by default.

examples/references.bib

Lines changed: 10 additions & 0 deletions
@@ -380,6 +380,16 @@ @article{nuijten2015default
   year = {2015},
   publisher = {Springer}
 }
+
+@misc{quiroga2022bart,
+  title = {Bayesian additive regression trees for probabilistic programming},
+  author = {Quiroga, Miriana and Garay, Pablo G and Alonso, Juan M. and Loyola, Juan Martin and Martin, Osvaldo A},
+  publisher = {arXiv},
+  doi = {10.48550/ARXIV.2206.03619},
+  url = {https://arxiv.org/abs/2206.03619},
+  keywords = {Computation (stat.CO), FOS: Computer and information sciences},
+  year = {2022},
+  copyright = {Creative Commons Attribution Share Alike 4.0 International}
+}
 
 @book{rasmussen2003gaussian,
   title = {Gaussian Processes for Machine Learning},
   author = {Rasmussen, Carl Edward and Williams, Christopher K. I.},

myst_nbs/case_studies/BART_introduction.myst.md

Lines changed: 9 additions & 8 deletions
@@ -6,7 +6,7 @@ jupytext:
     format_version: 0.13
     jupytext_version: 1.13.7
 kernelspec:
-  display_name: Python 3.9.7 ('base')
+  display_name: Python 3 (ipykernel)
   language: python
   name: python3
 ---
@@ -27,7 +27,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import pymc as pm
-import pymc_experimental as pmx
+import pymc_bart as pmb

 print(f"Running on PyMC v{pm.__version__}")
 ```
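A note for anyone applying this change locally: pymc-bart is its own package, so swapping the import also means swapping the dependency. A minimal sketch of the switch (assuming a standard pip setup):

```python
# The new dependency is a separate PyPI package, distinct from pymc-experimental:
#   pip uninstall pymc-experimental   # only if it was installed solely for BART
#   pip install pymc-bart
import pymc_bart as pmb  # note: the import name uses an underscore, the PyPI name a hyphen
```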
@@ -86,7 +86,7 @@ In PyMC a BART variable can be defined very similar to other random variables. O

 ```{code-cell} ipython3
 with pm.Model() as model_coal:
-    μ_ = pmx.BART("μ_", X=x_data, Y=y_data, m=20)
+    μ_ = pmb.BART("μ_", X=x_data, Y=y_data, m=20)
     μ = pm.Deterministic("μ", np.abs(μ_))
     y_pred = pm.Poisson("y_pred", mu=μ, observed=y_data)
     idata_coal = pm.sample(random_seed=RANDOM_SEED)
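The cell above assumes `x_data` and `y_data` already exist; in the notebook they come from discretizing the coal mining disaster dates into per-interval counts. A sketch of that preparation, following the notebook's binning approach (the 4-year bin width and `pm.get_data` source reflect our reading of the notebook, not this diff):

```python
import numpy as np
import pymc as pm

# Dates of UK coal mining disasters, shipped with PyMC.
coal = np.loadtxt(pm.get_data("coal.csv"))

# Discretize the dates into roughly 4-year bins and count disasters per bin.
years = int(coal.max() - coal.min())
hist, x_edges = np.histogram(coal, bins=years // 4)

# BART expects a 2D X; use the bin centers as the single covariate.
x_centers = x_edges[:-1] + (x_edges[1] - x_edges[0]) / 2
x_data = x_centers[:, None]
# Express the counts as a rate (disasters per year).
y_data = hist / 4
```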
@@ -113,7 +113,7 @@ In the previous plot the white line is the median over 4000 posterior draws, and
 The following figure shows two samples from the posterior of $\mu$. We can see that these functions are not smooth. This is fine and is a direct consequence of using regression trees. Trees can be seen as a way to represent stepwise functions, and a sum of stepwise functions is just another stepwise function. Thus, when using BART we just need to know that we are assuming that a stepwise function is a good enough approximation for our problem. In practice this is often the case because we sum over many trees, usually values like 50, 100 or 200. Additionally, we often average over the posterior distribution. All this makes the "steps" smoother, even though we never really get a smooth function as we would with, for example, Gaussian processes or splines. A nice theoretical result tells us that in the limit of $m \to \infty$ the BART prior converges to a [nowhere-differentiable](https://en.wikipedia.org/wiki/Weierstrass_function) Gaussian process.

 ```{code-cell} ipython3
-plt.step(x_data, np.exp(pmx.bart.predict(idata_coal, rng, x_data, size=2).T));
+plt.step(x_data, np.exp(pmb.predict(idata_coal, rng, x_data, size=2).squeeze().T));
 ```

 To gain further intuition, the next figures show 3 of the `m` trees. As we can see, these are definitely not very good approximators by themselves. Inspecting individual trees is generally not necessary; we are just showing them here to build intuition about BART.
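The claim that a sum of stepwise functions is just another stepwise function is easy to check numerically. A small self-contained illustration with synthetic step functions (not from the notebook):

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
grid = np.linspace(0, 1, 500)

# Build 50 random step functions, each piecewise constant between random breakpoints.
steps = []
for _ in range(50):
    breaks = np.sort(rng.uniform(0, 1, size=3))
    levels = rng.normal(size=4)
    steps.append(levels[np.searchsorted(breaks, grid)])

# The sum is still piecewise constant, just with many more (hence finer) steps.
plt.step(grid, np.sum(steps, axis=0))
plt.title("A sum of 50 step functions is still a step function")
```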
@@ -143,7 +143,7 @@ Y = bikes["count"]
 ```{code-cell} ipython3
 with pm.Model() as model_bikes:
     σ = pm.HalfNormal("σ", Y.std())
-    μ = pmx.BART("μ", X, Y, m=50)
+    μ = pmb.BART("μ", X, Y, m=50)
     y = pm.Normal("y", μ, σ, observed=Y)
     idata_bikes = pm.sample(random_seed=RANDOM_SEED)
 ```
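As with the coal model, `X` and `Y` are built earlier in the notebook: the hunk header shows `Y = bikes["count"]`, and the covariate labels used later suggest a selection like the following (the `pm.get_data` path is an assumption, not part of this diff):

```python
import pandas as pd
import pymc as pm

# Assumed load; the notebook ships the bike-sharing data alongside PyMC's datasets.
bikes = pd.read_csv(pm.get_data("bikes.csv"))

# The four covariates match the labels used for the variable-importance plot below.
X = bikes[["hour", "temperature", "humidity", "workingday"]]
Y = bikes["count"]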
@@ -155,7 +155,7 @@ with pm.Model() as model_bikes:
 To help us interpret the results of our model we are going to use partial dependence plots. This type of plot shows the marginal effect that one covariate has on the predicted variable, that is, the effect a covariate $X_i$ has on $Y$ while we average over all the other covariates ($X_j, \forall j \neq i$). This type of plot is not exclusive to BART, but it is often used in the BART literature. PyMC provides a utility function to make this plot from the inference data.

 ```{code-cell} ipython3
-pmx.bart.plot_dependence(idata_bikes, X=X, Y=Y, grid=(2, 2), var_discrete=[3]);
+pmb.plot_dependence(idata_bikes, X=X, Y=Y, grid=(2, 2), var_discrete=[3]);
 ```

 From this plot we can see the main effect of each covariate on the predicted value. This is very useful: we can recover complex relationships beyond monotonically increasing or decreasing effects. For example, for the `hour` covariate we can see two peaks, around 8 and 17 hours, and a minimum at midnight.
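As a sanity check on what `plot_dependence` computes, a partial dependence for one covariate can be sketched by hand: sweep $X_i$ over a grid, pin it across all rows, and average the predictions over rows and posterior draws. This sketch reuses the `pmb.predict` call shown earlier and is illustrative only:

```python
import numpy as np
import pymc_bart as pmb

def manual_partial_dependence(idata, rng, X, i, grid_size=20, samples=50):
    """Mean predicted response as covariate i sweeps its observed range."""
    X = np.asarray(X, dtype=float)
    grid = np.linspace(X[:, i].min(), X[:, i].max(), grid_size)
    means = []
    for value in grid:
        X_fixed = X.copy()
        X_fixed[:, i] = value  # pin covariate i; leave the others at their data values
        preds = pmb.predict(idata, rng, X_fixed, size=samples)
        means.append(preds.mean())  # average over posterior draws and over rows
    return grid, np.asarray(means)

# Usage with the notebook's names, e.g. hour as column 0:
# grid, pdp = manual_partial_dependence(idata_bikes, rng, X.values, i=0)
```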
@@ -176,12 +176,13 @@ Additionally, we provide a novel method to assess the variable importance. You c

 ```{code-cell} ipython3
 labels = ["hour", "temperature", "humidity", "workingday"]
-pmx.bart.utils.plot_variable_importance(idata_bikes, X.values, labels, samples=100);
+pmb.plot_variable_importance(idata_bikes, X.values, labels, samples=100);
 ```

 ## Authors
 * Authored by Osvaldo Martin in Dec, 2021 ([pymc-examples#259](https://github.com/pymc-devs/pymc-examples/pull/259))
 * Updated by Osvaldo Martin in May, 2022 ([pymc-examples#323](https://github.com/pymc-devs/pymc-examples/pull/323))
+* Updated by Osvaldo Martin in Sep, 2022

 +++

@@ -190,8 +191,8 @@ pmx.bart.utils.plot_variable_importance(idata_bikes, X.values, labels, samples=1
 :::{bibliography}
 :filter: docname in docnames

-martin2018bayesian
 martin2021bayesian
+quiroga2022bart
 :::

 +++
