diff --git a/env-dev.yml b/env-dev.yml new file mode 100644 index 0000000..1e28429 --- /dev/null +++ b/env-dev.yml @@ -0,0 +1,23 @@ +name: pymc-bart-dev +channels: + - conda-forge + - defaults +dependencies: + - pymc>=5.16.2,<=5.19.1 + - arviz>=0.18.0 + - numba + - matplotlib + - numpy + - pytensor + # Development dependencies + - pytest>=4.4.0 + - pytest-cov>=2.6.1 + - click==8.0.4 + - pylint==2.17.4 + - pre-commit + - black + - isort + - flake8 + - pip + - pip: + - -e . diff --git a/env.yml b/env.yml new file mode 100644 index 0000000..bd814ae --- /dev/null +++ b/env.yml @@ -0,0 +1,14 @@ +name: pymc-bart +channels: + - conda-forge + - defaults +dependencies: + - pymc>=5.16.2,<=5.19.1 + - arviz>=0.18.0 + - numba + - matplotlib + - numpy + - pytensor + - pip + - pip: + - pymc-bart diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index df8f76f..3ba6e58 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -254,13 +254,13 @@ def identity(x): ) new_x = fake_X[:, var] - p_d = np.array(y_pred) + p_d = func(np.array(y_pred)) for s_i in range(shape): if centered: - p_di = func(p_d[:, :, s_i]) - func(p_d[:, :, s_i][:, 0][:, None]) + p_di = p_d[:, :, s_i] - p_d[:, :, s_i][:, 0][:, None] else: - p_di = func(p_d[:, :, s_i]) + p_di = p_d[:, :, s_i] if var in var_discrete: axes[count].plot(new_x, p_di.mean(0), "o", color=color_mean) axes[count].plot(new_x, p_di.T, ".", color=color, alpha=alpha) @@ -393,14 +393,17 @@ def identity(x): for var in range(len(var_idx)): excluded = indices[:] excluded.remove(var) - p_d = _sample_posterior( - all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape + p_d = func( + _sample_posterior( + all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape + ) ) + with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="hdi currently interprets 2d data") new_x = fake_X[:, var] for s_i in range(shape): - p_di = func(p_d[:, :, s_i]) + p_di = p_d[:, :, s_i] null_pd.append(p_di.mean()) if var in var_discrete: _, idx_uni = np.unique(new_x, return_index=True) @@ -1125,8 +1128,11 @@ def plot_scatter_submodels( plot_kwargs : dict Additional keyword arguments for the plot. Defaults to None. Valid keys are: - - color_ref: matplotlib valid color for the 45 degree line + - marker_scatter: matplotlib valid marker for the scatter plot - color_scatter: matplotlib valid color for the scatter plot + - alpha_scatter: matplotlib valid alpha for the scatter plot + - color_ref: matplotlib valid color for the 45 degree line + - ls_ref: matplotlib valid linestyle for the reference line axes : axes Matplotlib axes. @@ -1140,41 +1146,69 @@ def plot_scatter_submodels( submodels = np.sort(submodels) indices = vi_results["indices"][submodels] - preds = vi_results["preds"][submodels] + preds_sub = vi_results["preds"][submodels] preds_all = vi_results["preds_all"] + if labels is None: + labels = vi_results["labels"][submodels] + + # handle categorical regression case: + n_cats = None + if preds_all.ndim > 2: + n_cats = preds_all.shape[-1] + indices = np.tile(indices, n_cats) + if ax is None: _, ax = _get_axes(grid, len(indices), True, True, figsize) if plot_kwargs is None: plot_kwargs = {} - if labels is None: - labels = vi_results["labels"][submodels] - if func is not None: - preds = func(preds) + preds_sub = func(preds_sub) preds_all = func(preds_all) - min_ = min(np.min(preds), np.min(preds_all)) - max_ = max(np.max(preds), np.max(preds_all)) - - for pred, x_label, axi in zip(preds, labels, ax.ravel()): - axi.plot( - pred, - preds_all, - marker=plot_kwargs.get("marker_scatter", "."), - ls="", - color=plot_kwargs.get("color_scatter", "C0"), - alpha=plot_kwargs.get("alpha_scatter", 0.1), - ) - axi.set_xlabel(x_label) - axi.axline( - [min_, min_], - [max_, max_], - color=plot_kwargs.get("color_ref", "0.5"), - ls=plot_kwargs.get("ls_ref", "--"), - ) + min_ = min(np.min(preds_sub), np.min(preds_all)) + max_ = max(np.max(preds_sub), np.max(preds_all)) + + # handle categorical regression case: + if n_cats is not None: + i = 0 + for cat in range(n_cats): + for pred_sub, x_label in zip(preds_sub, labels): + ax[i].plot( + pred_sub[..., cat], + preds_all[..., cat], + marker=plot_kwargs.get("marker_scatter", "."), + ls="", + color=plot_kwargs.get("color_scatter", f"C{cat}"), + alpha=plot_kwargs.get("alpha_scatter", 0.1), + ) + ax[i].set(xlabel=x_label, ylabel="ref model", title=f"Category {cat}") + ax[i].axline( + [min_, min_], + [max_, max_], + color=plot_kwargs.get("color_ref", "0.5"), + ls=plot_kwargs.get("ls_ref", "--"), + ) + i += 1 + else: + for pred_sub, x_label, axi in zip(preds_sub, labels, ax.ravel()): + axi.plot( + pred_sub, + preds_all, + marker=plot_kwargs.get("marker_scatter", "."), + ls="", + color=plot_kwargs.get("color_scatter", "C0"), + alpha=plot_kwargs.get("alpha_scatter", 0.1), + ) + axi.set(xlabel=x_label, ylabel="ref model") + axi.axline( + [min_, min_], + [max_, max_], + color=plot_kwargs.get("color_ref", "0.5"), + ls=plot_kwargs.get("ls_ref", "--"), + ) return ax