
BART-VI: use mean of leaf nodes when pruning instead of zero #54

Merged · 3 commits · Jun 30, 2022
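In brief, summarizing the diff below: `Tree.predict` used to zero out a prediction whenever the leaf's parent node split on an excluded variable. With this PR, `_traverse_tree` instead stops at the first split on an excluded variable and returns the mean of all leaf values beneath it (collected by the new `_traverse_leaf_values` helper), so excluding a variable yields a sensible partial prediction rather than zero. For example, if the subtree below an excluded split has leaves 2.0 and 4.0, the prediction is now 3.0 instead of 0.0. On the `utils.py` side, `plot_variable_importance` gains a `sort_vars` option and now reports R² (the squared Pearson correlation between full and partial predictions) instead of raw correlation.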
58 changes: 41 additions & 17 deletions pymc_experimental/bart/tree.py
@@ -49,10 +49,7 @@ class Tree:
     def __init__(self, num_observations=0, shape=1):
         self.tree_structure = {}
         self.idx_leaf_nodes = []
-        self.shape = shape
-        self.output = (
-            np.zeros((num_observations, self.shape)).astype(aesara.config.floatX).squeeze()
-        )
+        self.output = np.zeros((num_observations, shape)).astype(aesara.config.floatX).squeeze()

     def __getitem__(self, index):
         return self.get_node(index)
@@ -93,29 +90,30 @@ def _predict(self):
             output[leaf_node.idx_data_points] = leaf_node.value
         return output.T

-    def predict(self, X, excluded=None):
+    def predict(self, x, excluded=None):
         """
-        Predict output of tree for an (un)observed point X.
+        Predict output of tree for an (un)observed point x.

         Parameters
         ----------
-        X : numpy array
+        x : numpy array
             Unobserved point

         Returns
         -------
         float
             Value of the leaf node where the unobserved point lies.
         """
-        leaf_node = self._traverse_tree(X, node_index=0)
-        leaf_value = leaf_node.value
-        if excluded is not None:
-            parent_node = leaf_node.get_idx_parent_node()
-            if self.get_node(parent_node).idx_split_variable in excluded:
-                leaf_value = np.zeros(self.shape)
+        if excluded is None:
+            excluded = []
+        node = self._traverse_tree(x, 0, excluded)
+        if isinstance(node, LeafNode):
+            leaf_value = node.value
+        else:
+            leaf_value = node
         return leaf_value

-    def _traverse_tree(self, x, node_index=0):
+    def _traverse_tree(self, x, node_index, excluded):
         """
         Traverse the tree starting from a particular node given an unobserved point.
@@ -126,18 +124,44 @@ def _traverse_tree(self, x, node_index=0):

         Returns
         -------
-        LeafNode
+        LeafNode or mean of leaf node values
         """
         current_node = self.get_node(node_index)
         if isinstance(current_node, SplitNode):
+            if current_node.idx_split_variable in excluded:
+                leaf_values = []
+                self._traverse_leaf_values(leaf_values, node_index)
+                return np.mean(leaf_values, 0)
+
             if x[current_node.idx_split_variable] <= current_node.split_value:
                 left_child = current_node.get_idx_left_child()
-                current_node = self._traverse_tree(x, left_child)
+                current_node = self._traverse_tree(x, left_child, excluded)
             else:
                 right_child = current_node.get_idx_right_child()
-                current_node = self._traverse_tree(x, right_child)
+                current_node = self._traverse_tree(x, right_child, excluded)
         return current_node

+    def _traverse_leaf_values(self, leaf_values, node_index):
+        """
+        Traverse the tree appending leaf values starting from a particular node.
+
+        Parameters
+        ----------
+        node_index : int
+
+        Returns
+        -------
+        List of leaf node values
+        """
+        current_node = self.get_node(node_index)
+        if isinstance(current_node, SplitNode):
+            left_child = current_node.get_idx_left_child()
+            self._traverse_leaf_values(leaf_values, left_child)
+            right_child = current_node.get_idx_right_child()
+            self._traverse_leaf_values(leaf_values, right_child)
+        else:
+            leaf_values.append(current_node.value)
+
     @staticmethod
     def init_tree(leaf_node_value, idx_data_points, shape):
         """
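To see the excluded-variable logic end to end, here is a small self-contained sketch of the same traversal idea over toy node classes. `Leaf`, `Split`, `traverse`, and `collect_leaves` are hypothetical stand-ins for illustration, not the `Tree`/`SplitNode`/`LeafNode` API above:

```python
import numpy as np

class Leaf:
    def __init__(self, value):
        self.value = value

class Split:
    def __init__(self, var, threshold, left, right):
        self.var, self.threshold, self.left, self.right = var, threshold, left, right

def collect_leaves(node):
    # Analogous to _traverse_leaf_values: gather every leaf value below `node`.
    if isinstance(node, Leaf):
        return [node.value]
    return collect_leaves(node.left) + collect_leaves(node.right)

def traverse(node, x, excluded):
    # Analogous to the new _traverse_tree: stop at the first split on an
    # excluded variable and return the mean of the leaves beneath it.
    if isinstance(node, Leaf):
        return node.value
    if node.var in excluded:
        return np.mean(collect_leaves(node))
    child = node.left if x[node.var] <= node.threshold else node.right
    return traverse(child, x, excluded)

# Split on variable 0; the right branch splits again on variable 1.
tree = Split(0, 0.5, Leaf(2.0), Split(1, 1.0, Leaf(4.0), Leaf(6.0)))
print(traverse(tree, np.array([0.9, 2.0]), excluded=[]))   # 6.0, normal traversal
print(traverse(tree, np.array([0.9, 2.0]), excluded=[1]))  # 5.0 = mean(4.0, 6.0)
print(traverse(tree, np.array([0.9, 2.0]), excluded=[0]))  # 4.0 = mean(2.0, 4.0, 6.0)
```

Under the old behaviour, the last two calls would have returned 0.0 whenever the excluded variable fed the final leaf, which systematically biased the partial predictions toward zero.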
34 changes: 23 additions & 11 deletions pymc_experimental/bart/utils.py
@@ -309,7 +309,9 @@ def plot_dependence(
     return axes


-def plot_variable_importance(idata, X, labels=None, figsize=None, samples=100, random_seed=None):
+def plot_variable_importance(
+    idata, X, labels=None, sort_vars=True, figsize=None, samples=100, random_seed=None
+):
     """
     Estimates variable importance from the BART posterior.
@@ -319,9 +321,11 @@ def plot_variable_importance(idata, X, labels=None, figsize=None, samples=100, random_seed=None):
         InferenceData containing a collection of BART_trees in sample_stats group
     X : array-like
         The covariate matrix.
-    labels: list
+    labels : list
         List of the names of the covariates. If X is a DataFrame the names of the covariates will
         be taken from it and this argument will be ignored.
+    sort_vars : bool
+        Whether to sort the variables according to their variable importance. Defaults to True.
     figsize : tuple
         Figure size. If None it will be defined automatically.
     samples : int
@@ -337,23 +341,29 @@ def plot_variable_importance(idata, X, labels=None, figsize=None, samples=100, random_seed=None):
     _, axes = plt.subplots(2, 1, figsize=figsize)

     if hasattr(X, "columns") and hasattr(X, "values"):
-        labels = list(X.columns)
+        labels = X.columns
         X = X.values

     VI = idata.sample_stats["variable_inclusion"].mean(("chain", "draw")).values
     if labels is None:
-        labels = range(len(VI))
+        labels = np.arange(len(VI))
+    else:
+        labels = np.array(labels)

     ticks = np.arange(len(VI), dtype=int)
     idxs = np.argsort(VI)
     subsets = [idxs[:-i] for i in range(1, len(idxs))]
     subsets.append(None)

-    axes[0].plot(VI / VI.sum(), "o-")
+    if sort_vars:
+        indices = idxs[::-1]
+    else:
+        indices = np.arange(len(VI))
+    axes[0].plot((VI / VI.sum())[indices], "o-")
     axes[0].set_xticks(ticks)
-    axes[0].set_xticklabels(labels)
-    axes[0].set_xlabel("variable index")
-    axes[0].set_ylabel("relative importance")
+    axes[0].set_xticklabels(labels[indices])
+    axes[0].set_xlabel("covariates")
+    axes[0].set_ylabel("importance")

     predicted_all = predict(idata, rng, X=X, size=samples, excluded=None)
@@ -363,16 +373,18 @@ def plot_variable_importance(idata, X, labels=None, figsize=None, samples=100, random_seed=None):
         predicted_subset = predict(idata, rng, X=X, size=samples, excluded=subset)
         pearson = np.zeros(samples)
         for j in range(samples):
-            pearson[j] = pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]
+            pearson[j] = (
+                pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]
+            ) ** 2
         EV_mean[idx] = np.mean(pearson)
         EV_hdi[idx] = az.hdi(pearson)

     axes[1].errorbar(ticks, EV_mean, np.array((EV_mean - EV_hdi[:, 0], EV_hdi[:, 1] - EV_mean)))

     axes[1].set_xticks(ticks)
     axes[1].set_xticklabels(ticks + 1)
-    axes[1].set_xlabel("number of components")
-    axes[1].set_ylabel("correlation")
+    axes[1].set_xlabel("number of covariates")
+    axes[1].set_ylabel("R²", rotation=0, labelpad=12)
+    axes[1].set_ylim(0, 1)

     axes[0].set_xlim(-0.5, len(VI) - 0.5)
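Finally, a hedged usage sketch of the updated plotting helper. It assumes a BART model has already been fitted with pymc-experimental so that `idata` carries `sample_stats["variable_inclusion"]`; `idata` and `X` below are placeholders for your own objects, not values from this PR:

```python
from pymc_experimental.bart.utils import plot_variable_importance

# idata: InferenceData from a fitted BART model (BART_trees in sample_stats).
# X: the covariate matrix the model was fitted on; if it is a DataFrame,
#    its column names are used as labels and the `labels` argument is ignored.
plot_variable_importance(
    idata,
    X,
    sort_vars=True,   # new in this PR: order covariates by importance
    samples=100,      # posterior samples used for the R² panel
    random_seed=42,
)
```

The top panel then shows relative inclusion frequency per covariate (sorted when `sort_vars=True`), and the bottom panel shows R² between the full prediction and predictions from growing subsets of covariates, with its y-axis now fixed to [0, 1].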