From 84bf94a24ee8af94b4ed335a2e5f4967276c22a3 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Fri, 17 Mar 2023 15:36:46 +0100 Subject: [PATCH] brute force copy-paste https://github.com/pymc-devs/pymc/pull/5044 --- pymc_bart/bart.py | 5 +++++ pymc_bart/pgbart.py | 6 ++++++ pymc_bart/tree.py | 36 ++++++++++++++++++++++++++++-------- 3 files changed, 39 insertions(+), 8 deletions(-) diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index dad53d6..fb4c9db 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -73,6 +73,9 @@ class BART(Distribution): alpha : float Control the prior probability over the depth of the trees. Even when it can takes values in the interval (0, 1), it is recommended to be in the interval (0, 0.5]. + response : str + How the leaf_node values are computed. Available options are ``constant``, ``linear`` or + ``mix`` (default). split_prior : array-like Each element of split_prior should be in the [0, 1] interval and the elements should sum to 1. Otherwise they will be normalized. @@ -86,6 +89,7 @@ def __new__( Y, m=50, alpha=0.25, + response="mix", split_prior=None, **kwargs, ): @@ -109,6 +113,7 @@ def __new__( Y=Y, m=m, alpha=alpha, + response=response, split_prior=split_prior, ), )() diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 6755836..4d75dcd 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -161,9 +161,11 @@ def astep(self, _): self.X, self.missing_data, self.sum_trees, + self.linear_fit, self.m, self.normal, self.shape, + self.response, ) if tree_grew: self.update_weight(p) @@ -315,9 +317,11 @@ def sample_tree( X, missing_data, sum_trees, + linear_fit, m, normal, shape, + response, ): tree_grew = False if self.expansion_nodes: @@ -334,10 +338,12 @@ def sample_tree( X, missing_data, sum_trees, + linear_fit, m, normal, self.kfactor, shape, + response, ) if idx_new_nodes is not None: self.expansion_nodes.extend(idx_new_nodes) diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 4f98709..f4a58ec 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -113,7 +113,7 @@ def _predict(self): output[leaf_node.idx_data_points] = leaf_node.value return output.T - def predict(self, x, excluded=None): + def predict(self, x, m, excluded=None): """ Predict output of tree for an (un)observed point x. @@ -121,6 +121,8 @@ def predict(self, x, excluded=None): ---------- x : numpy array Unobserved point + m : int + Number of trees excluded: list Indexes of the variables to exclude when computing predictions @@ -131,7 +133,14 @@ def predict(self, x, excluded=None): """ if excluded is None: excluded = [] - return self._traverse_tree(x, 0, excluded) + + leaf_node, split_variable = self._traverse_tree(x, node_index=0, excluded=excluded) + if leaf_node.linear_params is None: + return self._traverse_tree(x, 0, excluded) + else: + x = x[split_variable].item() + y_x = leaf_node.linear_params[0] + leaf_node.linear_params[1] * x + return y_x / m def _traverse_tree(self, x, node_index, excluded): """ @@ -149,15 +158,16 @@ def _traverse_tree(self, x, node_index, excluded): current_node = self.get_node(node_index) if current_node.is_leaf_node(): return current_node.value - if current_node.idx_split_variable in excluded: + split_variable = current_node.idx_split_variable + if split_variable in excluded: leaf_values = [] self._traverse_leaf_values(leaf_values, node_index) return np.mean(leaf_values, 0) - if x[current_node.idx_split_variable] <= current_node.value: - next_node = current_node.get_idx_left_child() + if x[split_variable] <= current_node.value: + next_node = current_node.get_idx_left_child(), split_variable else: - next_node = current_node.get_idx_right_child() + next_node = current_node.get_idx_right_child(), split_variable return self._traverse_tree(x, next_node, excluded) def _traverse_leaf_values(self, leaf_values, node_index): @@ -181,13 +191,23 @@ def _traverse_leaf_values(self, leaf_values, node_index): class Node: - __slots__ = "index", "value", "idx_split_variable", "idx_data_points" + __slots__ = ( + "index", + "value", + "idx_split_variable", + "idx_data_points", + "idx_split_variable", + "linear_params", + ) - def __init__(self, index: int, value=-1, idx_data_points=None, idx_split_variable=-1): + def __init__( + self, index: int, value=-1, idx_data_points=None, idx_split_variable=-1, linear_params=None + ): self.index = index self.value = value self.idx_data_points = idx_data_points self.idx_split_variable = idx_split_variable + self.linear_params = linear_params @classmethod def new_leaf_node(cls, index: int, value, idx_data_points) -> "Node":