Skip to content

Enhance plot_pdp and fix plot_scatter_submodels #218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions env-dev.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
name: pymc-bart-dev
channels:
- conda-forge
- defaults
dependencies:
- pymc>=5.16.2,<=5.19.1
- arviz>=0.18.0
- numba
- matplotlib
- numpy
- pytensor
# Development dependencies
- pytest>=4.4.0
- pytest-cov>=2.6.1
- click==8.0.4
- pylint==2.17.4
- pre-commit
- black
- isort
- flake8
- pip
- pip:
- -e .
14 changes: 14 additions & 0 deletions env.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
name: pymc-bart
channels:
- conda-forge
- defaults
dependencies:
- pymc>=5.16.2,<=5.19.1
- arviz>=0.18.0
- numba
- matplotlib
- numpy
- pytensor
- pip
- pip:
- pymc-bart
96 changes: 65 additions & 31 deletions pymc_bart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,13 +254,13 @@ def identity(x):
)

new_x = fake_X[:, var]
p_d = np.array(y_pred)
p_d = func(np.array(y_pred))

for s_i in range(shape):
if centered:
p_di = func(p_d[:, :, s_i]) - func(p_d[:, :, s_i][:, 0][:, None])
p_di = p_d[:, :, s_i] - p_d[:, :, s_i][:, 0][:, None]
else:
p_di = func(p_d[:, :, s_i])
p_di = p_d[:, :, s_i]
if var in var_discrete:
axes[count].plot(new_x, p_di.mean(0), "o", color=color_mean)
axes[count].plot(new_x, p_di.T, ".", color=color, alpha=alpha)
Expand Down Expand Up @@ -393,14 +393,17 @@ def identity(x):
for var in range(len(var_idx)):
excluded = indices[:]
excluded.remove(var)
p_d = _sample_posterior(
all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape
p_d = func(
_sample_posterior(
all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape
)
)

with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="hdi currently interprets 2d data")
new_x = fake_X[:, var]
for s_i in range(shape):
p_di = func(p_d[:, :, s_i])
p_di = p_d[:, :, s_i]
null_pd.append(p_di.mean())
if var in var_discrete:
_, idx_uni = np.unique(new_x, return_index=True)
Expand Down Expand Up @@ -1125,8 +1128,11 @@ def plot_scatter_submodels(
plot_kwargs : dict
Additional keyword arguments for the plot. Defaults to None.
Valid keys are:
- color_ref: matplotlib valid color for the 45 degree line
- marker_scatter: matplotlib valid marker for the scatter plot
- color_scatter: matplotlib valid color for the scatter plot
- alpha_scatter: matplotlib valid alpha for the scatter plot
- color_ref: matplotlib valid color for the 45 degree line
- ls_ref: matplotlib valid linestyle for the reference line
axes : axes
Matplotlib axes.

Expand All @@ -1140,41 +1146,69 @@ def plot_scatter_submodels(
submodels = np.sort(submodels)

indices = vi_results["indices"][submodels]
preds = vi_results["preds"][submodels]
preds_sub = vi_results["preds"][submodels]
preds_all = vi_results["preds_all"]

if labels is None:
labels = vi_results["labels"][submodels]

# handle categorical regression case:
n_cats = None
if preds_all.ndim > 2:
n_cats = preds_all.shape[-1]
indices = np.tile(indices, n_cats)

if ax is None:
_, ax = _get_axes(grid, len(indices), True, True, figsize)

if plot_kwargs is None:
plot_kwargs = {}

if labels is None:
labels = vi_results["labels"][submodels]

if func is not None:
preds = func(preds)
preds_sub = func(preds_sub)
preds_all = func(preds_all)

min_ = min(np.min(preds), np.min(preds_all))
max_ = max(np.max(preds), np.max(preds_all))

for pred, x_label, axi in zip(preds, labels, ax.ravel()):
axi.plot(
pred,
preds_all,
marker=plot_kwargs.get("marker_scatter", "."),
ls="",
color=plot_kwargs.get("color_scatter", "C0"),
alpha=plot_kwargs.get("alpha_scatter", 0.1),
)
axi.set_xlabel(x_label)
axi.axline(
[min_, min_],
[max_, max_],
color=plot_kwargs.get("color_ref", "0.5"),
ls=plot_kwargs.get("ls_ref", "--"),
)
min_ = min(np.min(preds_sub), np.min(preds_all))
max_ = max(np.max(preds_sub), np.max(preds_all))

# handle categorical regression case:
if n_cats is not None:
i = 0
for cat in range(n_cats):
for pred_sub, x_label in zip(preds_sub, labels):
ax[i].plot(
pred_sub[..., cat],
preds_all[..., cat],
marker=plot_kwargs.get("marker_scatter", "."),
ls="",
color=plot_kwargs.get("color_scatter", f"C{cat}"),
alpha=plot_kwargs.get("alpha_scatter", 0.1),
)
ax[i].set(xlabel=x_label, ylabel="ref model", title=f"Category {cat}")
ax[i].axline(
[min_, min_],
[max_, max_],
color=plot_kwargs.get("color_ref", "0.5"),
ls=plot_kwargs.get("ls_ref", "--"),
)
i += 1
else:
for pred_sub, x_label, axi in zip(preds_sub, labels, ax.ravel()):
axi.plot(
pred_sub,
preds_all,
marker=plot_kwargs.get("marker_scatter", "."),
ls="",
color=plot_kwargs.get("color_scatter", "C0"),
alpha=plot_kwargs.get("alpha_scatter", 0.1),
)
axi.set(xlabel=x_label, ylabel="ref model")
axi.axline(
[min_, min_],
[max_, max_],
color=plot_kwargs.get("color_ref", "0.5"),
ls=plot_kwargs.get("ls_ref", "--"),
)
return ax


Expand Down