BART: Fully non-parametric curve fit example #519


Merged · 4 commits · Feb 3, 2023
505 changes: 505 additions & 0 deletions examples/case_studies/bart_heteroscedasticity.ipynb

Large diffs are not rendered by default.

166 changes: 166 additions & 0 deletions examples/case_studies/bart_heteroscedasticity.myst.md
@@ -0,0 +1,166 @@
---
jupytext:
text_representation:
extension: .md
format_name: myst
format_version: 0.13
kernelspec:
display_name: pymc-examples-env
language: python
name: python3
---

(bart_heteroscedasticity)=
# Modeling Heteroscedasticity with BART

:::{post} January, 2023
:tags: bart regression
:category: beginner, reference
:author: [Juan Orduz](https://juanitorduz.github.io/)
:::

+++

In this notebook we show how to use BART to model heteroscedasticity, as described in Section 4.1 of the [`pymc-bart`](https://github.com/pymc-devs/pymc-bart) paper {cite:p}`quiroga2022bart`. We use the `marketing` data set provided by the R package `datarium` {cite:p}`kassambara2019datarium`. The idea is to model the contribution of a marketing channel to sales as a function of its budget.

```{code-cell} ipython3
:tags: []

import os

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pymc_bart as pmb
```

```{code-cell} ipython3
:tags: []

%config InlineBackend.figure_format = "retina"
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [10, 6]
rng = np.random.default_rng(42)
```

## Read Data

```{code-cell} ipython3
try:
    df = pd.read_csv(os.path.join("..", "data", "marketing.csv"), sep=";", decimal=",")
except FileNotFoundError:
    df = pd.read_csv(pm.get_data("marketing.csv"), sep=";", decimal=",")

n_obs = df.shape[0]

df.head()
```

## EDA

We start by looking at the data. We are going to focus on the *Youtube* channel.

```{code-cell} ipython3
fig, ax = plt.subplots()
ax.plot(df["youtube"], df["sales"], "o", c="C0")
ax.set(title="Sales as a function of Youtube budget", xlabel="budget", ylabel="sales");
```

We clearly see that both the mean and the variance increase as a function of budget. One possibility would be to manually select an explicit parametrization of these functions, e.g. a square root or a logarithm (a sketch of this approach is shown below). However, in this example we want to learn these functions directly from the data using a BART model.
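
To make the contrast concrete, here is a minimal, hypothetical sketch of such a hand-picked parametrization. It is not used in the rest of this notebook: the model name `model_parametric`, the priors on `a` and `b`, and the square-root forms for the mean and noise scale are all illustrative assumptions.

```{code-cell} ipython3
budget = df["youtube"].to_numpy()

with pm.Model() as model_parametric:
    # Hand-picked functional forms (an assumption for illustration only):
    # both the mean and the noise scale grow like the square root of budget.
    a = pm.HalfNormal("a", sigma=10)
    b = pm.HalfNormal("b", sigma=10)
    mu = a * pm.math.sqrt(budget)
    sigma = b * pm.math.sqrt(budget)
    y = pm.Gamma("y", mu=mu, sigma=sigma, observed=df["sales"].to_numpy())
```

Every such choice hard-codes a shape that may not match the data; BART sidesteps this by learning the functions nonparametrically.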

+++

## Model Specification

We proceed to prepare the data for modeling. We are going to use the Youtube `budget` as the predictor and `sales` as the response.

```{code-cell} ipython3
X = df["youtube"].to_numpy().reshape(-1, 1)
Y = df["sales"].to_numpy()
```

Next, we specify the model. Note that we just need one BART distribution, which can be vectorized to model both the mean and the variance. We use a Gamma likelihood, as we expect sales to be positive.

```{code-cell} ipython3
with pm.Model() as model_marketing_full:
    # A single BART variable with shape (2, n_obs): the first row models the
    # mean and the second row the scale of the likelihood.
    w = pmb.BART(name="w", X=X, Y=Y, m=200, shape=(2, n_obs))
    # abs() keeps the scale parameter positive.
    y = pm.Gamma(name="y", mu=w[0], sigma=pm.math.abs(w[1]), observed=Y)

pm.model_to_graphviz(model=model_marketing_full)
```

We now fit the model.

```{code-cell} ipython3
with model_marketing_full:
    idata_marketing_full = pm.sample(random_seed=rng)
    posterior_predictive_marketing_full = pm.sample_posterior_predictive(
        trace=idata_marketing_full, random_seed=rng
    )
```

## Results

We can now visualize the posterior predictive distribution of the mean and the likelihood.

```{code-cell} ipython3
# Posterior mean of the first BART output, i.e. the mean function.
posterior_mean = idata_marketing_full.posterior["w"].mean(dim=("chain", "draw"))[0]

# 94% HDI of both BART outputs.
w_hdi = az.hdi(ary=idata_marketing_full, group="posterior", var_names=["w"])

# Posterior predictive samples, transposed so that samples lead.
pps = az.extract(
    posterior_predictive_marketing_full, group="posterior_predictive", var_names=["y"]
).T
```

```{code-cell} ipython3
idx = np.argsort(X[:, 0])

fig, ax = plt.subplots()
az.plot_hdi(x=X[:, 0], y=pps, ax=ax, fill_kwargs={"alpha": 0.3, "label": r"Likelihood $94\%$ HDI"})
az.plot_hdi(
    x=X[:, 0],
    hdi_data=w_hdi["w"].sel(w_dim_0=0),
    ax=ax,
    fill_kwargs={"alpha": 0.6, "label": r"Mean $94\%$ HDI"},
)
ax.plot(X[:, 0][idx], posterior_mean[idx], c="black", lw=3, label="Posterior Mean")
ax.plot(df["youtube"], df["sales"], "o", c="C0", label="Raw Data")
ax.legend(loc="upper left")
ax.set(
    title="Sales as a function of Youtube budget - Posterior Predictive",
    xlabel="budget",
    ylabel="sales",
);
```

The fit looks good! As expected, both the mean and the variance increase as a function of budget.
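
As a quick sanity check (a sketch added here, not part of the original analysis), we can plot the posterior mean of the second BART output, the one that feeds the Gamma scale through `abs`, directly against budget. The name `posterior_sigma` is illustrative.

```{code-cell} ipython3
# Posterior mean of the second BART output; abs() matches its use in the likelihood.
posterior_sigma = np.abs(idata_marketing_full.posterior["w"].mean(dim=("chain", "draw"))[1])

fig, ax = plt.subplots()
ax.plot(X[:, 0][idx], posterior_sigma[idx], c="C1", lw=3)
ax.set(title="Posterior mean of the noise scale", xlabel="budget", ylabel=r"$\sigma$");
```

An increasing curve here confirms the heteroscedasticity picked up by the model.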

+++

## Authors
- Authored by [Juan Orduz](https://juanitorduz.github.io/) in February 2023

+++

## References
:::{bibliography}
:filter: docname in docnames
:::

+++

## Watermark

```{code-cell} ipython3
:tags: []

%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor
```

:::{include} ../page_footer.md
:::
201 changes: 201 additions & 0 deletions examples/data/marketing.csv
@@ -0,0 +1,201 @@
youtube;facebook;newspaper;sales
276,12;45,36;83,04;26,52
53,40;47,16;54,12;12,48
20,64;55,08;83,16;11,16
181,80;49,56;70,20;22,20
216,96;12,96;70,08;15,48
10,44;58,68;90,00;8,64
69,00;39,36;28,20;14,16
144,24;23,52;13,92;15,84
10,32;2,52;1,20;5,76
239,76;3,12;25,44;12,72
79,32;6,96;29,04;10,32
257,64;28,80;4,80;20,88
28,56;42,12;79,08;11,04
117,00;9,12;8,64;11,64
244,92;39,48;55,20;22,80
234,48;57,24;63,48;26,88
81,36;43,92;136,80;15,00
337,68;47,52;66,96;29,28
83,04;24,60;21,96;13,56
176,76;28,68;22,92;17,52
262,08;33,24;64,08;21,60
284,88;6,12;28,20;15,00
15,84;19,08;59,52;6,72
273,96;20,28;31,44;18,60
74,76;15,12;21,96;11,64
315,48;4,20;23,40;14,40
171,48;35,16;15,12;18,00
288,12;20,04;27,48;19,08
298,56;32,52;27,48;22,68
84,72;19,20;48,96;12,60
351,48;33,96;51,84;25,68
135,48;20,88;46,32;14,28
116,64;1,80;36,00;11,52
318,72;24,00;0,36;20,88
114,84;1,68;8,88;11,40
348,84;4,92;10,20;15,36
320,28;52,56;6,00;30,48
89,64;59,28;54,84;17,64
51,72;32,04;42,12;12,12
273,60;45,24;38,40;25,80
243,00;26,76;37,92;19,92
212,40;40,08;46,44;20,52
352,32;33,24;2,16;24,84
248,28;10,08;31,68;15,48
30,12;30,84;51,96;10,20
210,12;27,00;37,80;17,88
107,64;11,88;42,84;12,72
287,88;49,80;22,20;27,84
272,64;18,96;59,88;17,76
80,28;14,04;44,16;11,64
239,76;3,72;41,52;13,68
120,48;11,52;4,32;12,84
259,68;50,04;47,52;27,12
219,12;55,44;70,44;25,44
315,24;34,56;19,08;24,24
238,68;59,28;72,00;28,44
8,76;33,72;49,68;6,60
163,44;23,04;19,92;15,84
252,96;59,52;45,24;28,56
252,84;35,40;11,16;22,08
64,20;2,40;25,68;9,72
313,56;51,24;65,64;29,04
287,16;18,60;32,76;18,84
123,24;35,52;10,08;16,80
157,32;51,36;34,68;21,60
82,80;11,16;1,08;11,16
37,80;29,52;2,64;11,40
167,16;17,40;12,24;16,08
284,88;33,00;13,20;22,68
260,16;52,68;32,64;26,76
238,92;36,72;46,44;21,96
131,76;17,16;38,04;14,88
32,16;39,60;23,16;10,56
155,28;6,84;37,56;13,20
256,08;29,52;15,72;20,40
20,28;52,44;107,28;10,44
33,00;1,92;24,84;8,28
144,60;34,20;17,04;17,04
6,48;35,88;11,28;6,36
139,20;9,24;27,72;13,20
91,68;32,04;26,76;14,16
287,76;4,92;44,28;14,76
90,36;24,36;39,00;13,56
82,08;53,40;42,72;16,32
256,20;51,60;40,56;26,04
231,84;22,08;78,84;18,24
91,56;33,00;19,20;14,40
132,84;48,72;75,84;19,20
105,96;30,60;88,08;15,48
131,76;57,36;61,68;20,04
161,16;5,88;11,16;13,44
34,32;1,80;39,60;8,76
261,24;40,20;70,80;23,28
301,08;43,80;86,76;26,64
128,88;16,80;13,08;13,80
195,96;37,92;63,48;20,28
237,12;4,20;7,08;14,04
221,88;25,20;26,40;18,60
347,64;50,76;61,44;30,48
162,24;50,04;55,08;20,64
266,88;5,16;59,76;14,04
355,68;43,56;121,08;28,56
336,24;12,12;25,68;17,76
225,48;20,64;21,48;17,64
285,84;41,16;6,36;24,84
165,48;55,68;70,80;23,04
30,00;13,20;35,64;8,64
108,48;0,36;27,84;10,44
15,72;0,48;30,72;6,36
306,48;32,28;6,60;23,76
270,96;9,84;67,80;16,08
290,04;45,60;27,84;26,16
210,84;18,48;2,88;16,92
251,52;24,72;12,84;19,08
93,84;56,16;41,40;17,52
90,12;42,00;63,24;15,12
167,04;17,16;30,72;14,64
91,68;0,96;17,76;11,28
150,84;44,28;95,04;19,08
23,28;19,20;26,76;7,92
169,56;32,16;55,44;18,60
22,56;26,04;60,48;8,40
268,80;2,88;18,72;13,92
147,72;41,52;14,88;18,24
275,40;38,76;89,04;23,64
104,64;14,16;31,08;12,72
9,36;46,68;60,72;7,92
96,24;0,00;11,04;10,56
264,36;58,80;3,84;29,64
71,52;14,40;51,72;11,64
0,84;47,52;10,44;1,92
318,24;3,48;51,60;15,24
10,08;32,64;2,52;6,84
263,76;40,20;54,12;23,52
44,28;46,32;78,72;12,96
57,96;56,40;10,20;13,92
30,72;46,80;11,16;11,40
328,44;34,68;71,64;24,96
51,60;31,08;24,60;11,52
221,88;52,68;2,04;24,84
88,08;20,40;15,48;13,08
232,44;42,48;90,72;23,04
264,60;39,84;45,48;24,12
125,52;6,84;41,28;12,48
115,44;17,76;46,68;13,68
168,36;2,28;10,80;12,36
288,12;8,76;10,44;15,84
291,84;58,80;53,16;30,48
45,60;48,36;14,28;13,08
53,64;30,96;24,72;12,12
336,84;16,68;44,40;19,32
145,20;10,08;58,44;13,92
237,12;27,96;17,04;19,92
205,56;47,64;45,24;22,80
225,36;25,32;11,40;18,72
4,92;13,92;6,84;3,84
112,68;52,20;60,60;18,36
179,76;1,56;29,16;12,12
14,04;44,28;54,24;8,76
158,04;22,08;41,52;15,48
207,00;21,72;36,84;17,28
102,84;42,96;59,16;15,96
226,08;21,72;30,72;17,88
196,20;44,16;8,88;21,60
140,64;17,64;6,48;14,28
281,40;4,08;101,76;14,28
21,48;45,12;25,92;9,60
248,16;6,24;23,28;14,64
258,48;28,32;69,12;20,52
341,16;12,72;7,68;18,00
60,00;13,92;22,08;10,08
197,40;25,08;56,88;17,40
23,52;24,12;20,40;9,12
202,08;8,52;15,36;14,04
266,88;4,08;15,72;13,80
332,28;58,68;50,16;32,40
298,08;36,24;24,36;24,24
204,24;9,36;42,24;14,04
332,04;2,76;28,44;14,16
198,72;12,00;21,12;15,12
187,92;3,12;9,96;12,60
262,20;6,48;32,88;14,64
67,44;6,84;35,64;10,44
345,12;51,60;86,16;31,44
304,56;25,56;36,00;21,12
246,00;54,12;23,52;27,12
167,40;2,52;31,92;12,36
229,32;34,44;21,84;20,76
343,20;16,68;4,44;19,08
22,44;14,52;28,08;8,04
47,40;49,32;6,96;12,96
90,60;12,96;7,20;11,88
20,64;4,92;37,92;7,08
200,16;50,40;4,32;23,52
179,64;42,72;7,20;20,76
45,84;4,44;16,56;9,12
113,04;5,88;9,72;11,64
212,40;11,16;7,68;15,36
340,32;50,40;79,44;30,60
278,52;10,32;10,44;16,08
7 changes: 7 additions & 0 deletions examples/references.bib
@@ -337,6 +337,13 @@ @article{johnson1999
title = {The Insignificance of Statistical Significance Testing},
journal = {The Journal of Wildlife Management}
}
@manual{kassambara2019datarium,
title = {datarium: Data Bank for Statistical Analysis and Visualization},
author = {Alboukadel Kassambara},
year = {2019},
note = {R package version 0.1.0},
url = {https://CRAN.R-project.org/package=datarium}
}
@misc{kingma2014autoencoding,
title = {Auto-Encoding Variational Bayes},
author = {Diederik P Kingma and Max Welling},