Commit 1f9ec4d

BART: Categorical example (#663)
* pymc-bart categorical example
* modifications from comments
* typo
* suggested changes
* typo
* metadata
* typo metadata
1 parent 8162c95 commit 1f9ec4d

File tree

4 files changed: +3758 −0 lines changed


examples/bart/bart_categorical_hawks.ipynb

Lines changed: 2575 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 273 additions & 0 deletions
@@ -0,0 +1,273 @@
---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
kernelspec:
  display_name: Python 3 (ipykernel)
  language: python
  name: python3
myst:
  substitutions:
    conda_dependencies: pymc-bart
    pip_dependencies: pymc-bart
---
+++ {"editable": true, "slideshow": {"slide_type": ""}}

(bart_categorical)=
# Categorical regression

:::{post} May, 2024
:tags: BART, regression
:category: beginner, reference
:author: Pablo Garay, Osvaldo Martin
:::

+++

In this example, we will model outcomes with more than two categories.

:::{include} ../extra_installs.md
:::
```{code-cell} ipython3
import os
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pymc_bart as pmb
import seaborn as sns

warnings.simplefilter(action="ignore", category=FutureWarning)
```
```{code-cell} ipython3
# set formats
RANDOM_SEED = 8457
az.style.use("arviz-darkgrid")
```
## Hawks dataset

Here we will use a dataset that contains information about 3 species of hawks (*CH*=Cooper's, *RT*=Red-tailed, *SS*=Sharp-Shinned). The dataset has information on 908 individuals in total, each described by 16 variables in addition to the species. To simplify the example, we will use the following 5 covariables:
- *Wing*: Length (in mm) of the primary wing feather from the tip to the wrist it attaches to.
- *Weight*: Body weight (in g).
- *Culmen*: Length (in mm) of the upper bill from the tip to where it bumps into the fleshy part of the bird.
- *Hallux*: Length (in mm) of the killing talon.
- *Tail*: Measurement (in mm) related to the length of the tail.

We will also drop the rows with missing values (NaNs). With these covariables we will predict the "Species" of each hawk; in other words, species is our dependent variable, the set of classes we want to predict.
```{code-cell} ipython3
# Load data and eliminate NaNs
try:
    Hawks = pd.read_csv(os.path.join("..", "data", "Hawks.csv"))[
        ["Wing", "Weight", "Culmen", "Hallux", "Tail", "Species"]
    ].dropna()
except FileNotFoundError:
    Hawks = pd.read_csv(pm.get_data("Hawks.csv"))[
        ["Wing", "Weight", "Culmen", "Hallux", "Tail", "Species"]
    ].dropna()

Hawks.head()
```
## EDA

The following pairplot compares the covariables, giving a quick visualization of the data for the 3 species.

```{code-cell} ipython3
sns.pairplot(Hawks, hue="Species");
```

We can see that the RT species has distributions that are more clearly differentiated from the other two in almost all covariables, and that Wing, Weight, and Culmen show some separation between species. Still, none of the variables separates the species distributions cleanly enough to classify them on its own. A combination of covariables, probably Wing, Weight, and Culmen, may be able to achieve the classification. These are the main reasons for fitting the regression.

+++
## Model Specification

First, we prepare the data for the model, using "Species" as the response and "Wing", "Weight", "Culmen", "Hallux", and "Tail" as predictors. With `pd.Categorical(Hawks['Species']).codes` we encode the species names as integers between 0 and 2, where 0="CH", 1="RT", and 2="SS".
```{code-cell} ipython3
y_0 = pd.Categorical(Hawks["Species"]).codes
x_0 = Hawks[["Wing", "Weight", "Culmen", "Hallux", "Tail"]]
print(len(x_0), x_0.shape, y_0.shape)
```
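For reference, a quick (optional) way to confirm this integer-to-label mapping could be the following sketch, which relies on `pd.Categorical` ordering the categories alphabetically:

```{code-cell} ipython3
# Optional sanity check: map each integer code back to its species label
dict(enumerate(pd.Categorical(Hawks["Species"]).categories))
```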
We can only have one instance of {class}`~pymc_bart.BART()` in each PyMC model (for now), so to model the 3 species we can use coordinate and dimension names to specify the shapes of the variables, *indicating* that there are 891 rows of information for 3 species. This step facilitates the later selection of groups from the `InferenceData`.
```{code-cell} ipython3
_, species = pd.factorize(Hawks["Species"], sort=True)
species
```

```{code-cell} ipython3
coords = {"n_obs": np.arange(len(x_0)), "species": species}
```
In this model we apply the `pm.math.softmax()` function to $\mu$, the output of `pmb.BART()`, because it guarantees that the values sum to 1 along `axis=0` (the species dimension in this case).
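To make the normalization concrete, here is a small NumPy sketch (the random 3×5 array of scores is purely illustrative, not model output) showing that a softmax applied along `axis=0` turns each column into probabilities that sum to 1:

```{code-cell} ipython3
# Illustrative only: softmax along axis=0 makes each column sum to 1
scores = np.random.default_rng(RANDOM_SEED).normal(size=(3, 5))  # 3 "species" x 5 "observations"
probs = np.exp(scores) / np.exp(scores).sum(axis=0)
probs.sum(axis=0)
```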
```{code-cell} ipython3
with pm.Model(coords=coords) as model_hawks:
    μ = pmb.BART("μ", x_0, y_0, m=50, dims=["species", "n_obs"])
    θ = pm.Deterministic("θ", pm.math.softmax(μ, axis=0))
    y = pm.Categorical("y", p=θ.T, observed=y_0)

pm.model_to_graphviz(model=model_hawks)
```
Now fit the model and get samples from the posterior.

```{code-cell} ipython3
with model_hawks:
    idata = pm.sample(chains=4, compute_convergence_checks=False, random_seed=123)
    pm.sample_posterior_predictive(idata, extend_inferencedata=True)
```
## Results

### Variable Importance

It may be that some of the input variables are not informative for classifying by species, so in the interest of parsimony, and to reduce the computational cost of model estimation, it is useful to quantify the importance of each variable in the dataset. PyMC-BART provides the function {func}`~pymc_bart.plot_variable_importance()`, which generates a plot that shows on its x-axis the number of covariables and on its y-axis the R$^2$ (the square of the Pearson correlation coefficient) between the predictions made by the full model (all variables included) and the restricted models, those with only a subset of the variables. The error bars represent the 94% HDI from the posterior predictive distribution.
```{code-cell} ipython3
---
editable: true
slideshow:
  slide_type: ''
---
pmb.plot_variable_importance(idata, μ, x_0, method="VI", random_seed=RANDOM_SEED);
```
It can be observed that with the covariables `Hallux`, `Culmen`, and `Wing` we achieve the same R$^2$ value that we obtained with all the covariables, which means that the remaining two covariables contribute less than the other three to the classification. One thing to keep in mind is that the HDI is quite wide, which gives us less precision in the results; later we will see a way to reduce this.

+++
### Partial Dependence Plot

Let's check the behavior of each covariable for each species with `pmb.plot_pdp()`, which shows the marginal effect a covariate has on the predicted variable, while we average over all the other covariates.

```{code-cell} ipython3
pmb.plot_pdp(μ, X=x_0, Y=y_0, grid=(5, 3), figsize=(6, 9));
```
The PDP plot, together with the Variable Importance plot, confirms that `Tail` is the covariable with the smallest effect on the predicted variable: in the Variable Importance plot `Tail` is the last covariable to be added and does not improve the result, and in the PDP plot `Tail` has the flattest response. For the rest of the covariables in this plot it is hard to tell which has the larger effect on the predicted variable, because they show great variability, reflected in the wide HDIs; as before, we will later see a way to reduce this variability. Finally, part of the variability depends on the amount of data for each species, which we can check via the `count` reported for one of the covariables by Pandas' `.describe()` after grouping the data with `.groupby("Species")`, as sketched below.
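A minimal way to do that check (using `Wing` as the covariable, though any of them would give the same counts) could be:

```{code-cell} ipython3
# The number of observations per species is given by the `count` column
Hawks.groupby("Species")["Wing"].describe()
```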
+++

### Predicted vs Observed

Now we are going to compare the predicted data with the observed data to evaluate the fit of the model. We do this with the ArviZ function `az.plot_ppc()`.
```{code-cell} ipython3
ax = az.plot_ppc(idata, kind="kde", num_pp_samples=200, random_seed=123)
# plot aesthetics
ax.set_ylim(0, 0.7)
ax.set_yticks([0, 0.2, 0.4, 0.6])
ax.set_ylabel("Probability")
ax.set_xticks([0.5, 1.5, 2.5])
ax.set_xticklabels(["CH", "RT", "SS"])
ax.set_xlabel("Species");
```
We can see a good agreement between the observed data (black line) and those predicted by the model (blue and orange lines). As we mentioned before, the difference in values between species is influenced by the amount of data available for each one. Here, the predicted data show none of the dispersion we saw in the previous two plots.

+++

Below we see that the in-sample predictions provide very good agreement with the observations.
```{code-cell} ipython3
# Overall in-sample accuracy: % of posterior predictive draws that match the observed class
np.mean((idata.posterior_predictive["y"] - y_0) == 0) * 100
```
```{code-cell} ipython3
# Per-species contribution to the overall in-sample accuracy
total = 0
for i in range(3):
    perct_per_class = np.mean(idata.posterior_predictive["y"].where(y_0 == i) == i) * 100
    total += perct_per_class
    print(perct_per_class)
total
```
So far we have a very good result for classifying the species based on the 5 covariables. However, if we want to select a subset of covariables for future classifications, it is not entirely clear which of them to select. What does seem safe is that `Tail` could be eliminated. At the beginning, when we plotted the distribution of each covariable, we said that the most important variables for the classification might be `Wing`, `Weight`, and `Culmen`; nevertheless, after running the model we saw that `Hallux`, `Culmen`, and `Wing` proved to be the most important ones.

Unfortunately, the partial dependence plots show a very wide dispersion, making the results look suspicious. One way to reduce this variability is to fit independent trees; below we will see how to do this and get a more accurate result.
+++
201+
202+
## Fitting independent trees
203+
204+
The option to fit independent trees with `pymc-bart` is set with the parameter `pmb.BART(..., separate_trees=True, ...)`. As we will see, for this example, using this option doesn't give a big difference in the predictions, but helps us to reduce the variability in the ppc and get a small improvement in the in-sample comparison. In case this option is used with bigger datasets you have to take into account that the model fits more slowly, so you can obtain a better result at the expense of computational cost. The following code runs the same model and analysis as before, but fitting independent trees. Compare the time to run this model with the previous one.
205+
206+
```{code-cell} ipython3
with pm.Model(coords=coords) as model_t:
    μ_t = pmb.BART("μ", x_0, y_0, m=50, separate_trees=True, dims=["species", "n_obs"])
    θ_t = pm.Deterministic("θ", pm.math.softmax(μ_t, axis=0))
    y_t = pm.Categorical("y", p=θ_t.T, observed=y_0)
    idata_t = pm.sample(chains=4, compute_convergence_checks=False, random_seed=123)
    pm.sample_posterior_predictive(idata_t, extend_inferencedata=True)
```
Now we are going to reproduce the same analyses as before.

```{code-cell} ipython3
pmb.plot_variable_importance(idata_t, μ_t, x_0, method="VI", random_seed=RANDOM_SEED);
```

```{code-cell} ipython3
pmb.plot_pdp(μ_t, X=x_0, Y=y_0, grid=(5, 3), figsize=(6, 9));
```
Comparing these two plots with the previous ones shows a marked reduction in the variance of each one. For `pmb.plot_variable_importance()` the error bands are smaller, with an R$^{2}$ value closer to 1. For `pmb.plot_pdp()` we can see thinner bands and a reduced range on the y-axis, reflecting the reduction in uncertainty that comes from fitting the trees separately. Another benefit is that the behavior of each covariable for each species is more clearly visible.

Taken together, this lets us select `Hallux`, `Culmen`, and `Wing` as the covariables to use for the classification.

+++

Concerning the comparison between observed and predicted data, we obtain the same good result with less uncertainty in the predicted values (blue lines). The same holds for the in-sample comparison.
```{code-cell} ipython3
ax = az.plot_ppc(idata_t, kind="kde", num_pp_samples=100, random_seed=123)
ax.set_ylim(0, 0.7)
ax.set_yticks([0, 0.2, 0.4, 0.6])
ax.set_ylabel("Probability")
ax.set_xticks([0.5, 1.5, 2.5])
ax.set_xticklabels(["CH", "RT", "SS"])
ax.set_xlabel("Species");
```
```{code-cell} ipython3
# Overall in-sample accuracy for the separate-trees model
np.mean((idata_t.posterior_predictive["y"] - y_0) == 0) * 100
```

```{code-cell} ipython3
# Per-species contribution to the overall in-sample accuracy
total = 0
for i in range(3):
    perct_per_class = np.mean(idata_t.posterior_predictive["y"].where(y_0 == i) == i) * 100
    total += perct_per_class
    print(perct_per_class)
total
```
## Authors
- Authored by [Pablo Garay](https://github.com/PabloGGaray) and [Osvaldo Martin](https://aloctavodia.github.io/) in May, 2024

+++

## References

:::{bibliography}
:filter: docname in docnames
:::

+++

## Watermark

```{code-cell} ipython3
%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor
```

:::{include} ../page_footer.md
:::

examples/conf.py

Lines changed: 1 addition & 0 deletions
@@ -181,6 +181,7 @@ def setup(app: Sphinx):
     "numpy": ("https://numpy.org/doc/stable/", None),
     "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
     "pymc": ("https://www.pymc.io/projects/docs/en/stable/", None),
+    "pymc-bart": ("https://www.pymc.io/projects/bart/en/latest/", None),
     "pytensor": ("https://pytensor.readthedocs.io/en/latest/", None),
     "pmx": ("https://www.pymc.io/projects/experimental/en/latest/", None),
     "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
