From 2c00358a9a1451eb60c916c90b70498cc562289c Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Sun, 16 Feb 2025 17:35:22 -0500 Subject: [PATCH 1/8] Add YML env files --- env-dev.yml | 23 +++++++++++++++++++++++ env.yml | 14 ++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 env-dev.yml create mode 100644 env.yml diff --git a/env-dev.yml b/env-dev.yml new file mode 100644 index 0000000..1e28429 --- /dev/null +++ b/env-dev.yml @@ -0,0 +1,23 @@ +name: pymc-bart-dev +channels: + - conda-forge + - defaults +dependencies: + - pymc>=5.16.2,<=5.19.1 + - arviz>=0.18.0 + - numba + - matplotlib + - numpy + - pytensor + # Development dependencies + - pytest>=4.4.0 + - pytest-cov>=2.6.1 + - click==8.0.4 + - pylint==2.17.4 + - pre-commit + - black + - isort + - flake8 + - pip + - pip: + - -e . diff --git a/env.yml b/env.yml new file mode 100644 index 0000000..bd814ae --- /dev/null +++ b/env.yml @@ -0,0 +1,14 @@ +name: pymc-bart +channels: + - conda-forge + - defaults +dependencies: + - pymc>=5.16.2,<=5.19.1 + - arviz>=0.18.0 + - numba + - matplotlib + - numpy + - pytensor + - pip + - pip: + - pymc-bart From d41e23942d4f1130ad09da450cfbaa6e7950bbd7 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Sun, 16 Feb 2025 17:35:54 -0500 Subject: [PATCH 2/8] Expand scatter_submodels to categorical likelihood --- pymc_bart/utils.py | 83 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 58 insertions(+), 25 deletions(-) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index df8f76f..f08c2c4 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -1125,8 +1125,11 @@ def plot_scatter_submodels( plot_kwargs : dict Additional keyword arguments for the plot. Defaults to None. Valid keys are: - - color_ref: matplotlib valid color for the 45 degree line + - marker_scatter: matplotlib valid marker for the scatter plot - color_scatter: matplotlib valid color for the scatter plot + - alpha_scatter: matplotlib valid alpha for the scatter plot + - color_ref: matplotlib valid color for the 45 degree line + - ls_ref: matplotlib valid linestyle for the reference line axes : axes Matplotlib axes. @@ -1140,41 +1143,71 @@ def plot_scatter_submodels( submodels = np.sort(submodels) indices = vi_results["indices"][submodels] - preds = vi_results["preds"][submodels] + preds_sub = vi_results["preds"][submodels] preds_all = vi_results["preds_all"] + if labels is None: + labels = vi_results["labels"][submodels] + + # handle categorical regression case: + n_cats = None + if preds_all.ndim > 2: + n_cats = preds_all.shape[-1] + indices = np.tile(indices, n_cats) + # labels = np.tile(labels, n_cats) + # cats = np.repeat(np.arange(n_cats), len(indices) // n_cats) + if ax is None: _, ax = _get_axes(grid, len(indices), True, True, figsize) if plot_kwargs is None: plot_kwargs = {} - if labels is None: - labels = vi_results["labels"][submodels] - if func is not None: - preds = func(preds) + preds_sub = func(preds_sub) preds_all = func(preds_all) - min_ = min(np.min(preds), np.min(preds_all)) - max_ = max(np.max(preds), np.max(preds_all)) - - for pred, x_label, axi in zip(preds, labels, ax.ravel()): - axi.plot( - pred, - preds_all, - marker=plot_kwargs.get("marker_scatter", "."), - ls="", - color=plot_kwargs.get("color_scatter", "C0"), - alpha=plot_kwargs.get("alpha_scatter", 0.1), - ) - axi.set_xlabel(x_label) - axi.axline( - [min_, min_], - [max_, max_], - color=plot_kwargs.get("color_ref", "0.5"), - ls=plot_kwargs.get("ls_ref", "--"), - ) + min_ = min(np.min(preds_sub), np.min(preds_all)) + max_ = max(np.max(preds_sub), np.max(preds_all)) + + # handle categorical regression case: + if n_cats is not None: + i = 0 + for cat in range(n_cats): + for pred_sub, x_label in zip(preds_sub, labels): + ax[i].plot( + pred_sub[..., cat], + preds_all[..., cat], + marker=plot_kwargs.get("marker_scatter", "."), + ls="", + color=plot_kwargs.get("color_scatter", f"C{cat}"), + alpha=plot_kwargs.get("alpha_scatter", 0.1), + ) + ax[i].set(xlabel=x_label, ylabel="ref model", title=f"Category {cat}") + ax[i].axline( + [min_, min_], + [max_, max_], + color=plot_kwargs.get("color_ref", "0.5"), + ls=plot_kwargs.get("ls_ref", "--"), + ) + i += 1 + else: + for pred_sub, x_label, axi in zip(preds_sub, labels, ax.ravel()): + axi.plot( + pred_sub, + preds_all, + marker=plot_kwargs.get("marker_scatter", "."), + ls="", + color=plot_kwargs.get("color_scatter", "C0"), + alpha=plot_kwargs.get("alpha_scatter", 0.1), + ) + axi.set(xlabel=x_label, ylabel="ref model") + axi.axline( + [min_, min_], + [max_, max_], + color=plot_kwargs.get("color_ref", "0.5"), + ls=plot_kwargs.get("ls_ref", "--"), + ) return ax From 957b0ac39c90a3ff7131b9729399dac675869bb6 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Sun, 16 Feb 2025 18:01:14 -0500 Subject: [PATCH 3/8] Add softmax option to plot_pdp --- pymc_bart/utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index f08c2c4..3b9ec64 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -290,6 +290,7 @@ def plot_pdp( var_idx: Optional[list[int]] = None, var_discrete: Optional[list[int]] = None, func: Optional[Callable] = None, + softmax_link: Optional[bool] = False, samples: int = 200, ref_line: bool = True, random_seed: Optional[int] = None, @@ -330,6 +331,9 @@ def plot_pdp( List of the indices of the covariate treated as discrete. func : Optional[Callable], by default None. Arbitrary function to apply to the predictions. Defaults to the identity function. + softmax_link: Optional[bool] = False, + If True the predictions are transformed using the softmax function. Only works when + likelihood is categorical. Defaults to False. samples : int Number of posterior samples used in the predictions. Defaults to 200 ref_line : bool @@ -396,6 +400,12 @@ def identity(x): p_d = _sample_posterior( all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape ) + if softmax_link is True: + from scipy.special import softmax + + # categories are the last dimension + p_d = softmax(p_d, axis=-1) + with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="hdi currently interprets 2d data") new_x = fake_X[:, var] From 733e66b560eb31ba4f565235e031d87522d543ca Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Sun, 16 Feb 2025 18:19:36 -0500 Subject: [PATCH 4/8] Remove comments --- pymc_bart/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index 3b9ec64..928b247 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -1164,8 +1164,6 @@ def plot_scatter_submodels( if preds_all.ndim > 2: n_cats = preds_all.shape[-1] indices = np.tile(indices, n_cats) - # labels = np.tile(labels, n_cats) - # cats = np.repeat(np.arange(n_cats), len(indices) // n_cats) if ax is None: _, ax = _get_axes(grid, len(indices), True, True, figsize) From 88847b67dd677e1490132760a59eef548b189285 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Mon, 17 Feb 2025 17:13:10 -0500 Subject: [PATCH 5/8] Use func for softmax --- pymc_bart/utils.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index 928b247..eb54e25 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -290,7 +290,6 @@ def plot_pdp( var_idx: Optional[list[int]] = None, var_discrete: Optional[list[int]] = None, func: Optional[Callable] = None, - softmax_link: Optional[bool] = False, samples: int = 200, ref_line: bool = True, random_seed: Optional[int] = None, @@ -331,9 +330,6 @@ def plot_pdp( List of the indices of the covariate treated as discrete. func : Optional[Callable], by default None. Arbitrary function to apply to the predictions. Defaults to the identity function. - softmax_link: Optional[bool] = False, - If True the predictions are transformed using the softmax function. Only works when - likelihood is categorical. Defaults to False. samples : int Number of posterior samples used in the predictions. Defaults to 200 ref_line : bool @@ -400,17 +396,18 @@ def identity(x): p_d = _sample_posterior( all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape ) - if softmax_link is True: - from scipy.special import softmax - - # categories are the last dimension - p_d = softmax(p_d, axis=-1) + # need to apply func to full array and to last dimension if it's softmax + if func.__name__ == "softmax": + # categories are always the last dimension + # for some reason, mypy thinks that func can be identity, + # which doesn't have the axis argument + p_d = func(p_d, axis=-1) # type: ignore[call-arg] with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="hdi currently interprets 2d data") new_x = fake_X[:, var] for s_i in range(shape): - p_di = func(p_d[:, :, s_i]) + p_di = p_d[:, :, s_i] if func.__name__ == "softmax" else func(p_d[:, :, s_i]) null_pd.append(p_di.mean()) if var in var_discrete: _, idx_uni = np.unique(new_x, return_index=True) From 444f73eefc0ef90070330eee0e124d7e00f05930 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Sat, 22 Feb 2025 20:27:04 -0500 Subject: [PATCH 6/8] handle func upstream --- pymc_bart/utils.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index eb54e25..cbc9d0c 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -393,21 +393,17 @@ def identity(x): for var in range(len(var_idx)): excluded = indices[:] excluded.remove(var) - p_d = _sample_posterior( - all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape + p_d = func( + _sample_posterior( + all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape + ) ) - # need to apply func to full array and to last dimension if it's softmax - if func.__name__ == "softmax": - # categories are always the last dimension - # for some reason, mypy thinks that func can be identity, - # which doesn't have the axis argument - p_d = func(p_d, axis=-1) # type: ignore[call-arg] with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="hdi currently interprets 2d data") new_x = fake_X[:, var] for s_i in range(shape): - p_di = p_d[:, :, s_i] if func.__name__ == "softmax" else func(p_d[:, :, s_i]) + p_di = p_d[:, :, s_i] null_pd.append(p_di.mean()) if var in var_discrete: _, idx_uni = np.unique(new_x, return_index=True) From 90862638b9025905cdf98b0437eb4d7082e91d8a Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Mon, 10 Mar 2025 08:56:38 +0200 Subject: [PATCH 7/8] move func upstream --- pymc_bart/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index cbc9d0c..9900576 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -254,13 +254,13 @@ def identity(x): ) new_x = fake_X[:, var] - p_d = np.array(y_pred) + p_d = func(y_pred) for s_i in range(shape): if centered: - p_di = func(p_d[:, :, s_i]) - func(p_d[:, :, s_i][:, 0][:, None]) + p_di = p_d[:, :, s_i] - p_d[:, :, s_i][:, 0][:, None] else: - p_di = func(p_d[:, :, s_i]) + p_di = p_d[:, :, s_i] if var in var_discrete: axes[count].plot(new_x, p_di.mean(0), "o", color=color_mean) axes[count].plot(new_x, p_di.T, ".", color=color, alpha=alpha) From abe70aa1cd676a012b8c3f8c89a3819326f7d89e Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Mon, 10 Mar 2025 09:15:17 +0200 Subject: [PATCH 8/8] ensure p_d is an array --- pymc_bart/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index 9900576..3ba6e58 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -254,7 +254,7 @@ def identity(x): ) new_x = fake_X[:, var] - p_d = func(y_pred) + p_d = func(np.array(y_pred)) for s_i in range(shape): if centered: