From fe38d539a203b1a8d5d455bf0fa7ae0df9cb3d35 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Mon, 3 Apr 2023 15:29:54 +0200 Subject: [PATCH 01/29] linear nodes init improve code logic add m param add m param to tests add signature in pbbart to add linear nodes --- pymc_bart/bart.py | 7 ++++++- pymc_bart/pgbart.py | 36 ++++++++++++++++++++++++++++++++++++ pymc_bart/tree.py | 32 +++++++++++++++++++++++++------- pymc_bart/utils.py | 17 ++++++++++++----- tests/test_bart.py | 8 ++++++-- 5 files changed, 85 insertions(+), 15 deletions(-) diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index b7f688f..7839d85 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -52,7 +52,7 @@ def rng_fn(cls, rng=None, X=None, Y=None, m=None, alpha=None, split_prior=None, else: return np.full(cls.Y.shape[0], cls.Y.mean()) else: - return _sample_posterior(cls.all_trees, cls.X, rng=rng).squeeze().T + return _sample_posterior(cls.all_trees, cls.X, cls.m, rng=rng).squeeze().T bart = BARTRV() @@ -72,6 +72,9 @@ class BART(Distribution): The response vector. m : int Number of trees + response : str + How the leaf_node values are computed. Available options are ``constant``, ``linear`` or + ``mix`` (default). alpha : float Control the prior probability over the depth of the trees. Even when it can takes values in the interval (0, 1), it is recommended to be in the interval (0, 0.5]. @@ -88,6 +91,7 @@ def __new__( Y: TensorLike, m: int = 50, alpha: float = 0.25, + response: str = "mix", split_prior: Optional[List[float]] = None, **kwargs, ): @@ -110,6 +114,7 @@ def __new__( X=X, Y=Y, m=m, + response=response, alpha=alpha, split_prior=split_prior, ), diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 81ea8b1..c2aac95 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -54,6 +54,8 @@ def sample_tree( missing_data, sum_trees, m, + linear_fit, + response, normal, shape, ) -> bool: @@ -73,6 +75,8 @@ def sample_tree( missing_data, sum_trees, m, + linear_fit, + response, normal, self.kfactor, shape, @@ -131,6 +135,7 @@ def __init__( self.missing_data = np.any(np.isnan(self.X)) self.m = self.bart.m + self.response = self.bart.response shape = initial_values[value_bart.name].shape self.shape = 1 if len(shape) == 1 else shape[0] @@ -160,6 +165,9 @@ def __init__( num_observations=self.num_observations, shape=self.shape, ) + + self.linear_fit = fast_linear_fit() + self.normal = NormalSampler(mu_std, self.shape) self.uniform = UniformSampler(0, 1) self.uniform_kf = UniformSampler(0.33, 0.75, self.shape) @@ -209,6 +217,8 @@ def astep(self, _): self.missing_data, self.sum_trees, self.m, + self.linear_fit, + self.response, self.normal, self.shape, ): @@ -393,6 +403,8 @@ def grow_tree( missing_data, sum_trees, m, + linear_fit, + response, normal, kfactor, shape, @@ -597,3 +609,27 @@ def logp(point, out_vars, vars, shared): # pylint: disable=redefined-builtin function = pytensor_function([inarray0], out_list[0]) function.trust_input = True return function + + +def fast_linear_fit(): + """If available use Numba to speed up the computation of the linear fit""" + + def linear_fit(X, Y): + + n = len(Y) + xbar = np.sum(X) / n + ybar = np.sum(Y) / n + + b = (X @ Y - n * xbar * ybar) / (X @ X - n * xbar**2) + a = ybar - b * xbar + + y_fit = a + b * X + return y_fit, (a, b) + + try: + from numba import jit + + return jit(linear_fit) + + except ImportError: + return linear_fit diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 463e01b..4950d96 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -27,20 +27,23 @@ class Node: ---------- value : float idx_data_points : Optional[npt.NDArray[np.int_]] - idx_split_variable : Optional[npt.NDArray[np.int_]] + idx_split_variable : int + linear_params: Optional[List[float]] = None """ - __slots__ = "value", "idx_split_variable", "idx_data_points" + __slots__ = "value", "idx_split_variable", "idx_data_points", "linear_params" def __init__( self, value: float = -1.0, idx_data_points: Optional[npt.NDArray[np.int_]] = None, idx_split_variable: int = -1, + linear_params: Optional[List[float]] = None, ) -> None: self.value = value self.idx_data_points = idx_data_points self.idx_split_variable = idx_split_variable + self.linear_params = linear_params @classmethod def new_leaf_node(cls, value: float, idx_data_points: Optional[npt.NDArray[np.int_]]) -> "Node": @@ -180,7 +183,7 @@ def _predict(self) -> npt.NDArray[np.float_]: return output.T def predict( - self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = None + self, x: npt.NDArray[np.float_], m: int, excluded: Optional[List[int]] = None ) -> npt.NDArray[np.float_]: """ Predict output of tree for an (un)observed point x. @@ -189,6 +192,8 @@ def predict( ---------- x : npt.NDArray[np.float_] Unobserved point + m : int + Number of trees excluded: Optional[List[int]] Indexes of the variables to exclude when computing predictions @@ -199,12 +204,14 @@ def predict( """ if excluded is None: excluded = [] - return self._traverse_tree(x, 0, excluded) + return self._traverse_tree(x=x, m=m, node_index=0, split_variable=-1, excluded=excluded) def _traverse_tree( self, x: npt.NDArray[np.float_], + m: int, node_index: int, + split_variable: int = -1, excluded: Optional[List[int]] = None, ) -> npt.NDArray[np.float_]: """ @@ -214,8 +221,12 @@ def _traverse_tree( ---------- x : npt.NDArray[np.float_] Unobserved point + m : int + Number of trees node_index : int Index of the node to start the traversal from + split_variable : int + Index of the variable used to split the node excluded: Optional[List[int]] Indexes of the variables to exclude when computing predictions @@ -224,9 +235,16 @@ def _traverse_tree( npt.NDArray[np.float_] Leaf node value or mean of leaf node values """ - current_node: Node = self.get_node(node_index) + current_node = self.get_node(node_index) if current_node.is_leaf_node(): - return np.array(current_node.value) + if current_node.linear_params is None: + return np.array(current_node.value) + + x = x[split_variable].item() + y_x = current_node.linear_params[0] + current_node.linear_params[1] * x + return np.array(y_x / m) + + split_variable = current_node.idx_split_variable if excluded is not None and current_node.idx_split_variable in excluded: leaf_values: List[float] = [] @@ -237,7 +255,7 @@ def _traverse_tree( next_node = get_idx_left_child(node_index) else: next_node = get_idx_right_child(node_index) - return self._traverse_tree(x=x, node_index=next_node, excluded=excluded) + return self._traverse_tree(x=x, node_index=next_node, split_variable=split_variable, excluded=excluded) def _traverse_leaf_values(self, leaf_values: List[float], node_index: int) -> None: """ diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index 45b9a75..d530196 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -21,6 +21,7 @@ def _sample_posterior( all_trees: List[List[Tree]], X: TensorLike, + m: int, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None, excluded: Optional[npt.NDArray[np.int_]] = None, @@ -35,6 +36,8 @@ def _sample_posterior( X : tensor-like A covariate matrix. Use the same used to fit BART for in-sample predictions or a new one for out-of-sample predictions. + m : int + Number of trees rng : NumPy RandomGenerator size : int or tuple Number of samples. @@ -57,7 +60,7 @@ def _sample_posterior( flatten_size *= s idx = rng.integers(0, len(stacked_trees), size=flatten_size) - shape = stacked_trees[0][0].predict(X[0]).size + shape = stacked_trees[0][0].predict(x=X[0], m=m).size pred = np.zeros((flatten_size, X.shape[0], shape)) @@ -220,6 +223,8 @@ def plot_dependence( ------- axes: matplotlib axes """ + m: int = bartrv.owner.op.m + if kind not in ["pdp", "ice"]: raise ValueError(f"kind={kind} is not suported. Available option are 'pdp' or 'ice'") @@ -294,7 +299,7 @@ def plot_dependence( new_X[:, indices_mi] = X[:, indices_mi] new_X[:, i] = x_i y_pred.append( - np.mean(_sample_posterior(all_trees, X=new_X, rng=rng, size=samples), 1) + np.mean(_sample_posterior(all_trees, X=new_X, m=m, rng=rng, size=samples), 1) ) new_x_target.append(new_x_i) else: @@ -302,7 +307,7 @@ def plot_dependence( new_X = X[idx_s] new_X[:, indices_mi] = X[:, indices_mi][instance] y_pred.append( - np.mean(_sample_posterior(all_trees, X=new_X, rng=rng, size=samples), 0) + np.mean(_sample_posterior(all_trees, X=new_X, m=m, rng=rng, size=samples), 0) ) new_x_target.append(new_X[:, i]) y_mins.append(np.min(y_pred)) @@ -445,6 +450,8 @@ def plot_variable_importance( """ _, axes = plt.subplots(2, 1, figsize=figsize) + m: int = bartrv.owner.op.m + if hasattr(X, "columns") and hasattr(X, "values"): labels = X.columns X = X.values @@ -474,13 +481,13 @@ def plot_variable_importance( all_trees = bartrv.owner.op.all_trees - predicted_all = _sample_posterior(all_trees, X=X, rng=rng, size=samples, excluded=None) + predicted_all = _sample_posterior(all_trees, X=X, m=m, rng=rng, size=samples, excluded=None) ev_mean = np.zeros(len(var_imp)) ev_hdi = np.zeros((len(var_imp), 2)) for idx, subset in enumerate(subsets): predicted_subset = _sample_posterior( - all_trees=all_trees, X=X, rng=rng, size=samples, excluded=subset + all_trees=all_trees, X=X, m=m, rng=rng, size=samples, excluded=subset ) pearson = np.zeros(samples) for j in range(samples): diff --git a/tests/test_bart.py b/tests/test_bart.py index d8a3a27..f08bb31 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -119,9 +119,13 @@ class TestUtils: def test_sample_posterior(self): all_trees = self.mu.owner.op.all_trees rng = np.random.default_rng(3) - pred_all = pmb.utils._sample_posterior(all_trees, X=self.X, rng=rng, size=2) + pred_all = pmb.utils._sample_posterior( + all_trees, X=self.X, m=self.mu.owner.op.m, rng=rng, size=2 + ) rng = np.random.default_rng(3) - pred_first = pmb.utils._sample_posterior(all_trees, X=self.X[:10], rng=rng) + pred_first = pmb.utils._sample_posterior( + all_trees, X=self.X[:10], m=self.mu.owner.op.m, rng=rng + ) assert_almost_equal(pred_first[0], pred_all[0, :10], decimal=4) assert pred_all.shape == (2, 50, 1) From ce3ed631ccbe99b1eb7049cffe672842424c1f77 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Thu, 13 Apr 2023 14:34:07 +0200 Subject: [PATCH 02/29] fix lint --- pymc_bart/tree.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 4950d96..70183a0 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -255,7 +255,9 @@ def _traverse_tree( next_node = get_idx_left_child(node_index) else: next_node = get_idx_right_child(node_index) - return self._traverse_tree(x=x, node_index=next_node, split_variable=split_variable, excluded=excluded) + return self._traverse_tree( + x=x, m=m, node_index=next_node, split_variable=split_variable, excluded=excluded + ) def _traverse_leaf_values(self, leaf_values: List[float], node_index: int) -> None: """ From 911ba739826b02f2294c8dfeca334cd62258c1fd Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Thu, 11 May 2023 20:11:36 +0200 Subject: [PATCH 03/29] Update pymc_bart/pgbart.py Co-authored-by: Osvaldo A Martin --- pymc_bart/pgbart.py | 29 +++++++++-------------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index c2aac95..87e6f53 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -611,25 +611,14 @@ def logp(point, out_vars, vars, shared): # pylint: disable=redefined-builtin return function -def fast_linear_fit(): - """If available use Numba to speed up the computation of the linear fit""" +def fast_linear_fit(X, Y): + """Use Numba to speed up the computation of the linear fit""" + n = len(Y) + xbar = np.sum(X) / n + ybar = np.sum(Y) / n - def linear_fit(X, Y): + b = (X @ Y - n * xbar * ybar) / (X @ X - n * xbar**2) + a = ybar - b * xbar - n = len(Y) - xbar = np.sum(X) / n - ybar = np.sum(Y) / n - - b = (X @ Y - n * xbar * ybar) / (X @ X - n * xbar**2) - a = ybar - b * xbar - - y_fit = a + b * X - return y_fit, (a, b) - - try: - from numba import jit - - return jit(linear_fit) - - except ImportError: - return linear_fit + y_fit = a + b * X + return y_fit, (a, b) From 613d0ecc6b65a69dcdf935a562f8ef5d68e385c7 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Thu, 11 May 2023 21:34:53 +0200 Subject: [PATCH 04/29] address comments --- .pylintrc | 3 +++ pymc_bart/pgbart.py | 62 ++++++++++++++++++++++++--------------------- pymc_bart/tree.py | 15 +++++++++-- 3 files changed, 49 insertions(+), 31 deletions(-) diff --git a/.pylintrc b/.pylintrc index de9c879..322f29a 100644 --- a/.pylintrc +++ b/.pylintrc @@ -256,6 +256,9 @@ good-names=i, rv, new_X, new_y, + a, + b, + n, # Include a hint for the correct naming format with invalid-name diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 87e6f53..2e9f611 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -54,7 +54,6 @@ def sample_tree( missing_data, sum_trees, m, - linear_fit, response, normal, shape, @@ -75,7 +74,6 @@ def sample_tree( missing_data, sum_trees, m, - linear_fit, response, normal, self.kfactor, @@ -166,8 +164,6 @@ def __init__( shape=self.shape, ) - self.linear_fit = fast_linear_fit() - self.normal = NormalSampler(mu_std, self.shape) self.uniform = UniformSampler(0, 1) self.uniform_kf = UniformSampler(0.33, 0.75, self.shape) @@ -217,7 +213,6 @@ def astep(self, _): self.missing_data, self.sum_trees, self.m, - self.linear_fit, self.response, self.normal, self.shape, @@ -403,7 +398,6 @@ def grow_tree( missing_data, sum_trees, m, - linear_fit, response, normal, kfactor, @@ -429,16 +423,19 @@ def grow_tree( for idx in range(2): idx_data_point = new_idx_data_points[idx] - node_value = draw_leaf_value( - sum_trees[:, idx_data_point], - m, - normal.rvs() * kfactor, - shape, + node_value, linear_parms = draw_leaf_value( + y_mu_pred=X[idx_data_point, selected_predictor], + x_mu=sum_trees[:, idx_data_point], + m=m, + norm=normal.rvs() * kfactor, + shape=shape, + response=response, ) new_node = Node.new_leaf_node( value=node_value, idx_data_points=idx_data_point, + linear_params=linear_parms, ) tree.set_node(current_node_children[idx], new_node) @@ -468,17 +465,23 @@ def get_split_value(available_splitting_values, idx_data_points, missing_data): @njit -def draw_leaf_value(y_mu_pred, m, norm, shape): +def draw_leaf_value(y_mu_pred, x_mu, m, norm, shape, response): """Draw Gaussian distributed leaf values.""" + linear_params = None if y_mu_pred.size == 0: - return np.zeros(shape) + mu_mean = np.zeros(shape) - if y_mu_pred.size == 1: + elif y_mu_pred.size == 1: mu_mean = np.full(shape, y_mu_pred.item() / m) - else: + + elif response == "constant": mu_mean = fast_mean(y_mu_pred) / m - return norm + mu_mean + elif response == "linear": + y_fit, _ = fast_linear_fit(x=x_mu, y=y_mu_pred) + mu_mean = y_fit / m + + return norm + mu_mean, linear_params @njit @@ -500,6 +503,20 @@ def fast_mean(ari): return res / count +@njit +def fast_linear_fit(x, y): + """Use Numba to speed up the computation of the linear fit""" + n = len(x) + xbar = np.sum(x) / n + ybar = np.sum(y) / n + + b = (x @ y - n * xbar * ybar) / (x @ x - n * xbar**2) + a = ybar - b * xbar + + y_fit = a + b * x + return y_fit, (a, b) + + def discrete_uniform_sampler(upper_value): """Draw from the uniform distribution with bounds [0, upper_value). @@ -609,16 +626,3 @@ def logp(point, out_vars, vars, shared): # pylint: disable=redefined-builtin function = pytensor_function([inarray0], out_list[0]) function.trust_input = True return function - - -def fast_linear_fit(X, Y): - """Use Numba to speed up the computation of the linear fit""" - n = len(Y) - xbar = np.sum(X) / n - ybar = np.sum(Y) / n - - b = (X @ Y - n * xbar * ybar) / (X @ X - n * xbar**2) - a = ybar - b * xbar - - y_fit = a + b * X - return y_fit, (a, b) diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 70183a0..88ca4d8 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -46,8 +46,19 @@ def __init__( self.linear_params = linear_params @classmethod - def new_leaf_node(cls, value: float, idx_data_points: Optional[npt.NDArray[np.int_]]) -> "Node": - return cls(value=value, idx_data_points=idx_data_points) + def new_leaf_node( + cls, + value: float, + idx_data_points: Optional[npt.NDArray[np.int_]] = None, + idx_split_variable: int = -1, + linear_params: Optional[List[float]] = None, + ) -> "Node": + return cls( + value=value, + idx_data_points=idx_data_points, + idx_split_variable=idx_split_variable, + linear_params=linear_params, + ) @classmethod def new_split_node(cls, split_value: float, idx_split_variable: int) -> "Node": From e9c67c9555964503bbf63d599c1d16e4b5524412 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Thu, 11 May 2023 22:00:33 +0200 Subject: [PATCH 05/29] fix arg error --- pymc_bart/pgbart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 2e9f611..f0535b8 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -478,7 +478,7 @@ def draw_leaf_value(y_mu_pred, x_mu, m, norm, shape, response): mu_mean = fast_mean(y_mu_pred) / m elif response == "linear": - y_fit, _ = fast_linear_fit(x=x_mu, y=y_mu_pred) + y_fit, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred) mu_mean = y_fit / m return norm + mu_mean, linear_params From e5c8bd3a596f1f8b77922f05ea7e6b98fde1a94e Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Thu, 11 May 2023 22:20:04 +0200 Subject: [PATCH 06/29] add bracket --- pymc_bart/pgbart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index f0535b8..0e301ca 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -481,7 +481,7 @@ def draw_leaf_value(y_mu_pred, x_mu, m, norm, shape, response): y_fit, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred) mu_mean = y_fit / m - return norm + mu_mean, linear_params + return (norm + mu_mean), linear_params @njit From 64dcd07400f3d14d90c271525854db2171def63d Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Fri, 12 May 2023 14:57:42 +0200 Subject: [PATCH 07/29] clean func --- pymc_bart/pgbart.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 0e301ca..a94a5b3 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -423,7 +423,7 @@ def grow_tree( for idx in range(2): idx_data_point = new_idx_data_points[idx] - node_value, linear_parms = draw_leaf_value( + node_value, linear_params = draw_leaf_value( y_mu_pred=X[idx_data_point, selected_predictor], x_mu=sum_trees[:, idx_data_point], m=m, @@ -435,7 +435,7 @@ def grow_tree( new_node = Node.new_leaf_node( value=node_value, idx_data_points=idx_data_point, - linear_params=linear_parms, + linear_params=linear_params, ) tree.set_node(current_node_children[idx], new_node) @@ -467,21 +467,22 @@ def get_split_value(available_splitting_values, idx_data_points, missing_data): @njit def draw_leaf_value(y_mu_pred, x_mu, m, norm, shape, response): """Draw Gaussian distributed leaf values.""" - linear_params = None if y_mu_pred.size == 0: mu_mean = np.zeros(shape) elif y_mu_pred.size == 1: mu_mean = np.full(shape, y_mu_pred.item() / m) - elif response == "constant": - mu_mean = fast_mean(y_mu_pred) / m + else: + if response == "constant": + mu_mean = fast_mean(y_mu_pred) / m - elif response == "linear": - y_fit, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred) - mu_mean = y_fit / m + elif response == "linear": + y_fit, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred) + mu_mean = fast_mean(y_fit) / m - return (norm + mu_mean), linear_params + draws = norm + mu_mean + return draws, linear_params @njit From 201b85e145e465aed2fd320b545a98863db74bf5 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Fri, 12 May 2023 15:03:47 +0200 Subject: [PATCH 08/29] fix rm init param linear_params --- pymc_bart/pgbart.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index a94a5b3..cf07b69 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -467,6 +467,7 @@ def get_split_value(available_splitting_values, idx_data_points, missing_data): @njit def draw_leaf_value(y_mu_pred, x_mu, m, norm, shape, response): """Draw Gaussian distributed leaf values.""" + linear_params = None if y_mu_pred.size == 0: mu_mean = np.zeros(shape) From 7ff4a21923af02ffa28e26ad83fc78cc05b54429 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Fri, 12 May 2023 15:31:13 +0200 Subject: [PATCH 09/29] clean up --- pymc_bart/pgbart.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index cf07b69..3fd8dce 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -469,21 +469,18 @@ def draw_leaf_value(y_mu_pred, x_mu, m, norm, shape, response): """Draw Gaussian distributed leaf values.""" linear_params = None if y_mu_pred.size == 0: - mu_mean = np.zeros(shape) + return np.zeros(shape), linear_params - elif y_mu_pred.size == 1: + if y_mu_pred.size == 1: mu_mean = np.full(shape, y_mu_pred.item() / m) - else: - if response == "constant": - mu_mean = fast_mean(y_mu_pred) / m - - elif response == "linear": - y_fit, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred) + if response == "linear": + y_fit, linear_params = fast_linear_fit(x_mu, y_mu_pred) mu_mean = fast_mean(y_fit) / m + else: + mu_mean = fast_mean(y_mu_pred) / m - draws = norm + mu_mean - return draws, linear_params + return norm + mu_mean, linear_params @njit From 6c327df98bb4f2b6081fd4926c15eae574059183 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Fri, 12 May 2023 22:15:32 +0200 Subject: [PATCH 10/29] minor improvements --- pymc_bart/pgbart.py | 14 +++++++------- tests/test_pgbart.py | 5 ++--- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 3fd8dce..c31b149 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -424,8 +424,8 @@ def grow_tree( for idx in range(2): idx_data_point = new_idx_data_points[idx] node_value, linear_params = draw_leaf_value( - y_mu_pred=X[idx_data_point, selected_predictor], - x_mu=sum_trees[:, idx_data_point], + y_mu_pred=sum_trees[:, idx_data_point], + x_mu=X[idx_data_point, selected_predictor], m=m, norm=normal.rvs() * kfactor, shape=shape, @@ -475,12 +475,12 @@ def draw_leaf_value(y_mu_pred, x_mu, m, norm, shape, response): mu_mean = np.full(shape, y_mu_pred.item() / m) else: if response == "linear": - y_fit, linear_params = fast_linear_fit(x_mu, y_mu_pred) - mu_mean = fast_mean(y_fit) / m + y_fit, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred) + mu_mean = fast_mean(y_fit / m) else: mu_mean = fast_mean(y_mu_pred) / m - - return norm + mu_mean, linear_params + draw = norm + mu_mean + return draw, linear_params @njit @@ -513,7 +513,7 @@ def fast_linear_fit(x, y): a = ybar - b * xbar y_fit = a + b * x - return y_fit, (a, b) + return y_fit, [a, b] def discrete_uniform_sampler(upper_value): diff --git a/tests/test_pgbart.py b/tests/test_pgbart.py index ec9fbac..5bdce63 100644 --- a/tests/test_pgbart.py +++ b/tests/test_pgbart.py @@ -4,8 +4,7 @@ import pymc as pm import pymc_bart as pmb -from pymc_bart.pgbart import (NormalSampler, UniformSampler, - discrete_uniform_sampler, fast_mean) +from pymc_bart.pgbart import NormalSampler, UniformSampler, discrete_uniform_sampler, fast_mean class TestSystematic(TestCase): @@ -24,7 +23,7 @@ def test_systematic(self): indices = step.systematic(normalized_weights) self.assertEqual(len(indices), len(normalized_weights)) - self.assertEqual(indices.dtype, np.int) + self.assertEqual(indices.dtype, np.int_) self.assertTrue(all(i >= 0 and i < len(normalized_weights) for i in indices)) normalized_weights = np.array([0, 0.25, 0.75]) From 83abae6289f9821cdccaf0ccfe0b55ec4f308f77 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Fri, 12 May 2023 22:19:03 +0200 Subject: [PATCH 11/29] fix mean denominator --- pymc_bart/pgbart.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index c31b149..35d433b 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -468,6 +468,7 @@ def get_split_value(available_splitting_values, idx_data_points, missing_data): def draw_leaf_value(y_mu_pred, x_mu, m, norm, shape, response): """Draw Gaussian distributed leaf values.""" linear_params = None + mu_mean = np.empty(shape) if y_mu_pred.size == 0: return np.zeros(shape), linear_params @@ -476,7 +477,7 @@ def draw_leaf_value(y_mu_pred, x_mu, m, norm, shape, response): else: if response == "linear": y_fit, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred) - mu_mean = fast_mean(y_fit / m) + mu_mean = fast_mean(y_fit) / m else: mu_mean = fast_mean(y_mu_pred) / m draw = norm + mu_mean From 796c9a69ffd001187abb7fa2a9b6096798c33e63 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Fri, 12 May 2023 22:30:08 +0200 Subject: [PATCH 12/29] fix shape --- pymc_bart/pgbart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 35d433b..41234d9 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -477,7 +477,7 @@ def draw_leaf_value(y_mu_pred, x_mu, m, norm, shape, response): else: if response == "linear": y_fit, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred) - mu_mean = fast_mean(y_fit) / m + mu_mean = fast_mean(y_fit.reshape(-1, 1)) / m else: mu_mean = fast_mean(y_mu_pred) / m draw = norm + mu_mean From f4e7074eb9ee74ffce5e5cd3600333bc12a8d1ef Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Fri, 12 May 2023 23:28:33 +0200 Subject: [PATCH 13/29] add tests --- tests/test_bart.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_bart.py b/tests/test_bart.py index f08bb31..ab71fd5 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -40,13 +40,18 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True): assert np.isfinite(logp_moment) -def test_bart_vi(): +@pytest.mark.parametrize( + argnames="response", + argvalues=["constant", "linear"], + ids=["constant-response", "linear-response"], +) +def test_bart_vi(response): X = np.random.normal(0, 1, size=(250, 3)) Y = np.random.normal(0, 1, size=250) X[:, 0] = np.random.normal(Y, 0.1) with pm.Model() as model: - mu = pmb.BART("mu", X, Y, m=10) + mu = pmb.BART("mu", X, Y, m=10, response=response) sigma = pm.HalfNormal("sigma", 1) y = pm.Normal("y", mu, sigma, observed=Y) idata = pm.sample(random_seed=3415) From 5b7f67c7b5c63a284132cafc0d5e19a01a41a566 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Sat, 13 May 2023 21:59:13 +0200 Subject: [PATCH 14/29] fix working implementation --- pymc_bart/pgbart.py | 4 ++-- tests/test_bart.py | 31 +++++++++++++++++++++++-------- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 41234d9..96d25fe 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -477,7 +477,7 @@ def draw_leaf_value(y_mu_pred, x_mu, m, norm, shape, response): else: if response == "linear": y_fit, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred) - mu_mean = fast_mean(y_fit.reshape(-1, 1)) / m + mu_mean = fast_mean(y_fit.reshape(1, -1)) / m else: mu_mean = fast_mean(y_mu_pred) / m draw = norm + mu_mean @@ -510,7 +510,7 @@ def fast_linear_fit(x, y): xbar = np.sum(x) / n ybar = np.sum(y) / n - b = (x @ y - n * xbar * ybar) / (x @ x - n * xbar**2) + b = (x @ y.T - n * xbar * ybar) / (x @ x - n * xbar**2) a = ybar - b * xbar y_fit = a + b * x diff --git a/tests/test_bart.py b/tests/test_bart.py index ab71fd5..28e23f0 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -42,8 +42,8 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True): @pytest.mark.parametrize( argnames="response", - argvalues=["constant", "linear"], - ids=["constant-response", "linear-response"], + argvalues=["linear"], + ids=["linear-response"], ) def test_bart_vi(response): X = np.random.normal(0, 1, size=(250, 3)) @@ -65,25 +65,35 @@ def test_bart_vi(response): assert_almost_equal(var_imp.sum(), 1) -def test_missing_data(): +@pytest.mark.parametrize( + argnames="response", + argvalues=["constant", "linear"], + ids=["constant", "linear-response"], +) +def test_missing_data(response): X = np.random.normal(0, 1, size=(50, 2)) Y = np.random.normal(0, 1, size=50) X[10:20, 0] = np.nan with pm.Model() as model: - mu = pmb.BART("mu", X, Y, m=10) + mu = pmb.BART("mu", X, Y, m=10, response=response) sigma = pm.HalfNormal("sigma", 1) y = pm.Normal("y", mu, sigma, observed=Y) idata = pm.sample(tune=100, draws=100, chains=1, random_seed=3415) -def test_shared_variable(): +@pytest.mark.parametrize( + argnames="response", + argvalues=["constant", "linear"], + ids=["constant", "linear-response"], +) +def test_shared_variable(response): X = np.random.normal(0, 1, size=(50, 2)) Y = np.random.normal(0, 1, size=50) with pm.Model() as model: data_X = pm.MutableData("data_X", X) - mu = pmb.BART("mu", data_X, Y, m=2) + mu = pmb.BART("mu", data_X, Y, m=2, response=response) sigma = pm.HalfNormal("sigma", 1) y = pm.Normal("y", mu, sigma, observed=Y, shape=mu.shape) idata = pm.sample(tune=100, draws=100, chains=2, random_seed=3415) @@ -95,12 +105,17 @@ def test_shared_variable(): assert ppc2.posterior_predictive["y"].shape == (2, 100, 3) -def test_shape(): +@pytest.mark.parametrize( + argnames="response", + argvalues=["constant", "linear"], + ids=["constant", "linear-response"], +) +def test_shape(response): X = np.random.normal(0, 1, size=(250, 3)) Y = np.random.normal(0, 1, size=250) with pm.Model() as model: - w = pmb.BART("w", X, Y, m=2, shape=(2, 250)) + w = pmb.BART("w", X, Y, m=2, response=response, shape=(2, 250)) y = pm.Normal("y", w[0], pm.math.abs(w[1]), observed=Y) idata = pm.sample(random_seed=3415) From de714f9a39ea4f14fee6181ff1151f58595fc74a Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Sat, 13 May 2023 22:42:31 +0200 Subject: [PATCH 15/29] try fix shapre vect --- pymc_bart/pgbart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 96d25fe..e66b612 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -477,7 +477,7 @@ def draw_leaf_value(y_mu_pred, x_mu, m, norm, shape, response): else: if response == "linear": y_fit, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred) - mu_mean = fast_mean(y_fit.reshape(1, -1)) / m + mu_mean = fast_mean(y_fit.reshape(-1, 1)) / m else: mu_mean = fast_mean(y_mu_pred) / m draw = norm + mu_mean From 7ec226e1087ae278e15882b9fc4ad34c6a6071e5 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Sat, 13 May 2023 23:37:30 +0200 Subject: [PATCH 16/29] undo last commit --- pymc_bart/pgbart.py | 2 +- tests/test_bart.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index e66b612..96d25fe 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -477,7 +477,7 @@ def draw_leaf_value(y_mu_pred, x_mu, m, norm, shape, response): else: if response == "linear": y_fit, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred) - mu_mean = fast_mean(y_fit.reshape(-1, 1)) / m + mu_mean = fast_mean(y_fit.reshape(1, -1)) / m else: mu_mean = fast_mean(y_mu_pred) / m draw = norm + mu_mean diff --git a/tests/test_bart.py b/tests/test_bart.py index 28e23f0..3da9a2f 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -42,8 +42,8 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True): @pytest.mark.parametrize( argnames="response", - argvalues=["linear"], - ids=["linear-response"], + argvalues=["constant", "linear"], + ids=["constant", "linear-response"], ) def test_bart_vi(response): X = np.random.normal(0, 1, size=(250, 3)) From 89fd59e45fcdfdc4a8f78c7142ba481f6d181d73 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Tue, 16 May 2023 17:23:55 +0200 Subject: [PATCH 17/29] Update pymc_bart/pgbart.py Co-authored-by: Osvaldo A Martin --- pymc_bart/pgbart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 96d25fe..a5d5685 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -477,7 +477,7 @@ def draw_leaf_value(y_mu_pred, x_mu, m, norm, shape, response): else: if response == "linear": y_fit, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred) - mu_mean = fast_mean(y_fit.reshape(1, -1)) / m + mu_mean = y_fit.reshape(1, -1) / m else: mu_mean = fast_mean(y_mu_pred) / m draw = norm + mu_mean From 996596f2053bd4bb1b942a5b483d712d84b042c0 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Tue, 16 May 2023 21:20:22 +0200 Subject: [PATCH 18/29] address comments --- pymc_bart/pgbart.py | 9 ++++++--- pymc_bart/tree.py | 7 ++++++- pymc_bart/utils.py | 2 +- tests/test_pgbart.py | 25 +++++++++++++++++++++++-- 4 files changed, 36 insertions(+), 7 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index a5d5685..b090513 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -421,6 +421,9 @@ def grow_tree( get_idx_right_child(index_leaf_node), ) + if response == "mix": + response = "linear" if np.random.random() >= 0.5 else "constant" + for idx in range(2): idx_data_point = new_idx_data_points[idx] node_value, linear_params = draw_leaf_value( @@ -477,7 +480,7 @@ def draw_leaf_value(y_mu_pred, x_mu, m, norm, shape, response): else: if response == "linear": y_fit, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred) - mu_mean = y_fit.reshape(1, -1) / m + mu_mean = y_fit / m else: mu_mean = fast_mean(y_mu_pred) / m draw = norm + mu_mean @@ -510,11 +513,11 @@ def fast_linear_fit(x, y): xbar = np.sum(x) / n ybar = np.sum(y) / n - b = (x @ y.T - n * xbar * ybar) / (x @ x - n * xbar**2) + b = ((x - xbar) @ (y - ybar).T) / ((x - xbar) @ (x - xbar).T) a = ybar - b * xbar y_fit = a + b * x - return y_fit, [a, b] + return y_fit, [a.item(), b.item()] def discrete_uniform_sampler(upper_value): diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 88ca4d8..aad6129 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -150,7 +150,12 @@ def __setitem__(self, index, node) -> None: def copy(self) -> "Tree": tree: Dict[int, Node] = { - k: Node(v.value, v.idx_data_points, v.idx_split_variable) + k: Node( + value=v.value, + idx_data_points=v.idx_data_points, + idx_split_variable=v.idx_split_variable, + linear_params=v.linear_params, + ) for k, v in self.tree_structure.items() } idx_leaf_nodes = self.idx_leaf_nodes.copy() if self.idx_leaf_nodes is not None else None diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index d530196..60e6c7d 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -66,7 +66,7 @@ def _sample_posterior( for ind, p in enumerate(pred): for tree in stacked_trees[idx[ind]]: - p += np.vstack([tree.predict(x, excluded) for x in X]) + p += np.vstack([tree.predict(x=x, m=m, excluded=excluded) for x in X]) pred.reshape((*size_iter, shape, -1)) return pred diff --git a/tests/test_pgbart.py b/tests/test_pgbart.py index 5bdce63..fddc50e 100644 --- a/tests/test_pgbart.py +++ b/tests/test_pgbart.py @@ -1,10 +1,16 @@ from unittest import TestCase - +import pytest import numpy as np import pymc as pm import pymc_bart as pmb -from pymc_bart.pgbart import NormalSampler, UniformSampler, discrete_uniform_sampler, fast_mean +from pymc_bart.pgbart import ( + NormalSampler, + UniformSampler, + discrete_uniform_sampler, + fast_mean, + fast_linear_fit, +) class TestSystematic(TestCase): @@ -39,6 +45,21 @@ def test_fast_mean(): np.testing.assert_array_almost_equal(fast_mean(values), np.mean(values, 1)) +@pytest.mark.parametrize( + argnames="x,y,a_expected, b_expected", + argvalues=[ + (np.array([1, 2, 3, 4, 5]), np.array([1, 2, 3, 4, 5]), 0.0, 1.0), + (np.array([1, 2, 3, 4, 5]), np.array([1, 1, 1, 1, 1]), 1.0, 0.0), + ], + ids=["1d-id", "1d-const"], +) +def test_fast_linear_fit(x, y, a_expected, b_expected): + y_fit, linear_params = fast_linear_fit(x, y) + assert linear_params[0] == a_expected + assert linear_params[1] == b_expected + np.testing.assert_almost_equal(actual=y_fit, desired=a_expected + x * b_expected) + + def test_discrete_uniform(): sample = discrete_uniform_sampler(7) assert isinstance(sample, int) From 0c0209d8b86068a8b9e320c0cdbbf710bbcd59e0 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Tue, 16 May 2023 21:45:48 +0200 Subject: [PATCH 19/29] change default response --- pymc_bart/bart.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index 7839d85..b10bbc6 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -74,7 +74,7 @@ class BART(Distribution): Number of trees response : str How the leaf_node values are computed. Available options are ``constant``, ``linear`` or - ``mix`` (default). + ``mix``. Defaults to ``constant``. alpha : float Control the prior probability over the depth of the trees. Even when it can takes values in the interval (0, 1), it is recommended to be in the interval (0, 0.5]. @@ -91,7 +91,7 @@ def __new__( Y: TensorLike, m: int = 50, alpha: float = 0.25, - response: str = "mix", + response: str = "constant", split_prior: Optional[List[float]] = None, **kwargs, ): From f4fcce14d720f0cf5203cbd24541abe0c6f73999 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Wed, 17 May 2023 00:12:44 +0200 Subject: [PATCH 20/29] better unpacking --- pymc_bart/pgbart.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index b090513..a1a72a7 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -426,11 +426,15 @@ def grow_tree( for idx in range(2): idx_data_point = new_idx_data_points[idx] - node_value, linear_params = draw_leaf_value( - y_mu_pred=sum_trees[:, idx_data_point], - x_mu=X[idx_data_point, selected_predictor], + y_mu_pred = sum_trees[:, idx_data_point] + x_mu = X[idx_data_point, selected_predictor] + norm = normal.rvs() * kfactor + + node_value, *linear_params = draw_leaf_value( + y_mu_pred=y_mu_pred, + x_mu=x_mu, m=m, - norm=normal.rvs() * kfactor, + norm=norm, shape=shape, response=response, ) @@ -517,7 +521,7 @@ def fast_linear_fit(x, y): a = ybar - b * xbar y_fit = a + b * x - return y_fit, [a.item(), b.item()] + return y_fit, [a, b] def discrete_uniform_sampler(upper_value): From db767ce89e1b473ba33be5a626788d4568e33e6d Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Wed, 17 May 2023 00:52:25 +0200 Subject: [PATCH 21/29] some type hints --- pymc_bart/pgbart.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index a1a72a7..1d6c3a8 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -426,15 +426,11 @@ def grow_tree( for idx in range(2): idx_data_point = new_idx_data_points[idx] - y_mu_pred = sum_trees[:, idx_data_point] - x_mu = X[idx_data_point, selected_predictor] - norm = normal.rvs() * kfactor - - node_value, *linear_params = draw_leaf_value( - y_mu_pred=y_mu_pred, - x_mu=x_mu, + node_value, linear_params = draw_leaf_value( + y_mu_pred=sum_trees[:, idx_data_point], + x_mu=X[idx_data_point, selected_predictor], m=m, - norm=norm, + norm=normal.rvs() * kfactor, shape=shape, response=response, ) @@ -472,7 +468,14 @@ def get_split_value(available_splitting_values, idx_data_points, missing_data): @njit -def draw_leaf_value(y_mu_pred, x_mu, m, norm, shape, response): +def draw_leaf_value( + y_mu_pred: npt.NDArray[np.float_], + x_mu: npt.NDArray[np.float_], + m: int, + norm: npt.NDArray[np.float_], + shape: int, + response: str, +) -> Tuple[npt.NDArray[np.float_], Optional[List[float]]]: """Draw Gaussian distributed leaf values.""" linear_params = None mu_mean = np.empty(shape) @@ -482,17 +485,18 @@ def draw_leaf_value(y_mu_pred, x_mu, m, norm, shape, response): if y_mu_pred.size == 1: mu_mean = np.full(shape, y_mu_pred.item() / m) else: + if response == "constant": + mu_mean = fast_mean(y_mu_pred) / m if response == "linear": y_fit, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred) mu_mean = y_fit / m - else: - mu_mean = fast_mean(y_mu_pred) / m + draw = norm + mu_mean return draw, linear_params @njit -def fast_mean(ari): +def fast_mean(ari: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_]]: """Use Numba to speed up the computation of the mean.""" if ari.ndim == 1: @@ -511,7 +515,9 @@ def fast_mean(ari): @njit -def fast_linear_fit(x, y): +def fast_linear_fit( + x: npt.NDArray[np.float_], y: npt.NDArray[np.float_] +) -> Tuple[npt.NDArray[np.float_], List[float]]: """Use Numba to speed up the computation of the linear fit""" n = len(x) xbar = np.sum(x) / n @@ -521,7 +527,7 @@ def fast_linear_fit(x, y): a = ybar - b * xbar y_fit = a + b * x - return y_fit, [a, b] + return y_fit, [a.item(), b.item()] def discrete_uniform_sampler(upper_value): From 6d4678423b2d7827d5af1289f93e21e6c2906d67 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Thu, 18 May 2023 23:20:59 +0200 Subject: [PATCH 22/29] fix trim method --- pymc_bart/tree.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index aad6129..1d07e2f 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -180,7 +180,13 @@ def grow_leaf_node( def trim(self) -> "Tree": tree: Dict[int, Node] = { - k: Node(v.value, None, v.idx_split_variable) for k, v in self.tree_structure.items() + k: Node( + value=v.value, + idx_data_points=None, + idx_split_variable=v.idx_split_variable, + linear_params=v.linear_params, + ) + for k, v in self.tree_structure.items() } return Tree(tree_structure=tree, idx_leaf_nodes=None, output=np.array([-1])) From c2de9fb3acef3e611539bb5b2ec2bc1f393a279e Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Tue, 23 May 2023 21:08:43 +0200 Subject: [PATCH 23/29] handle zero variance fast linear fit --- pymc_bart/pgbart.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 1d6c3a8..0808b5f 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -523,7 +523,16 @@ def fast_linear_fit( xbar = np.sum(x) / n ybar = np.sum(y) / n - b = ((x - xbar) @ (y - ybar).T) / ((x - xbar) @ (x - xbar).T) + x_diff = x - xbar + y_diff = y - ybar + + x_var = x_diff @ x_diff.T + + if x_var == 0: + b = np.zeros(1) + else: + b = (x_diff @ y_diff.T) / x_var + a = ybar - b * xbar y_fit = a + b * x From 8c6963d27a063836273bcdb91abbf433771598f2 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Tue, 23 May 2023 22:31:24 +0200 Subject: [PATCH 24/29] simplify condition --- pymc_bart/pgbart.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 0808b5f..ff7ea3a 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -528,9 +528,7 @@ def fast_linear_fit( x_var = x_diff @ x_diff.T - if x_var == 0: - b = np.zeros(1) - else: + if x_var != 0: b = (x_diff @ y_diff.T) / x_var a = ybar - b * xbar From f2c546a33059fd8f96375f5fb8a81b2f364a7fc0 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Thu, 25 May 2023 01:07:59 -0300 Subject: [PATCH 25/29] linear response in more than 1d --- pymc_bart/pgbart.py | 22 +++++++++++----------- pymc_bart/tree.py | 2 +- tests/test_pgbart.py | 8 +++++--- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index ff7ea3a..f9b4c76 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -467,7 +467,6 @@ def get_split_value(available_splitting_values, idx_data_points, missing_data): return split_value -@njit def draw_leaf_value( y_mu_pred: npt.NDArray[np.float_], x_mu: npt.NDArray[np.float_], @@ -475,7 +474,7 @@ def draw_leaf_value( norm: npt.NDArray[np.float_], shape: int, response: str, -) -> Tuple[npt.NDArray[np.float_], Optional[List[float]]]: +) -> Tuple[npt.NDArray[np.float_], Optional[npt.NDArray[np.float_]]]: """Draw Gaussian distributed leaf values.""" linear_params = None mu_mean = np.empty(shape) @@ -498,7 +497,6 @@ def draw_leaf_value( @njit def fast_mean(ari: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_]]: """Use Numba to speed up the computation of the mean.""" - if ari.ndim == 1: count = ari.shape[0] suma = 0 @@ -518,23 +516,25 @@ def fast_mean(ari: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_ def fast_linear_fit( x: npt.NDArray[np.float_], y: npt.NDArray[np.float_] ) -> Tuple[npt.NDArray[np.float_], List[float]]: - """Use Numba to speed up the computation of the linear fit""" n = len(x) + xbar = np.sum(x) / n - ybar = np.sum(y) / n + ybar = np.sum(y, axis=1) / n x_diff = x - xbar - y_diff = y - ybar + y_diff = y - np.expand_dims(ybar, axis=1) - x_var = x_diff @ x_diff.T + x_var = np.dot(x_diff, x_diff.T) - if x_var != 0: - b = (x_diff @ y_diff.T) / x_var + if x_var == 0: + b = np.zeros(y.shape[0]) + else: + b = np.dot(x_diff, y_diff.T) / x_var a = ybar - b * xbar - y_fit = a + b * x - return y_fit, [a.item(), b.item()] + y_fit = np.expand_dims(a, axis=1) + np.expand_dims(b, axis=1) * x + return y_fit.T, [a, b] def discrete_uniform_sampler(upper_value): diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 1d07e2f..9cb8088 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -201,7 +201,7 @@ def _predict(self) -> npt.NDArray[np.float_]: if self.idx_leaf_nodes is not None: for node_index in self.idx_leaf_nodes: leaf_node = self.get_node(node_index) - output[leaf_node.idx_data_points] = leaf_node.value + output[leaf_node.idx_data_points] = leaf_node.value.squeeze() return output.T def predict( diff --git a/tests/test_pgbart.py b/tests/test_pgbart.py index fddc50e..841b8be 100644 --- a/tests/test_pgbart.py +++ b/tests/test_pgbart.py @@ -48,8 +48,8 @@ def test_fast_mean(): @pytest.mark.parametrize( argnames="x,y,a_expected, b_expected", argvalues=[ - (np.array([1, 2, 3, 4, 5]), np.array([1, 2, 3, 4, 5]), 0.0, 1.0), - (np.array([1, 2, 3, 4, 5]), np.array([1, 1, 1, 1, 1]), 1.0, 0.0), + (np.array([1, 2, 3, 4, 5]), np.array([[1, 2, 3, 4, 5]]), 0.0, 1.0), + (np.array([1, 2, 3, 4, 5]), np.array([[1, 1, 1, 1, 1]]), 1.0, 0.0), ], ids=["1d-id", "1d-const"], ) @@ -57,7 +57,9 @@ def test_fast_linear_fit(x, y, a_expected, b_expected): y_fit, linear_params = fast_linear_fit(x, y) assert linear_params[0] == a_expected assert linear_params[1] == b_expected - np.testing.assert_almost_equal(actual=y_fit, desired=a_expected + x * b_expected) + np.testing.assert_almost_equal( + actual=y_fit, desired=np.atleast_2d(a_expected + x * b_expected).T + ) def test_discrete_uniform(): From 779af0d6697075cb15af9365bae9758fecb7c821 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Mon, 29 May 2023 20:50:29 +0200 Subject: [PATCH 26/29] make node value type an array so that is compatible with vectorization --- pymc_bart/pgbart.py | 2 +- pymc_bart/tree.py | 24 +++++++++++++++--------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index f9b4c76..3523008 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -534,7 +534,7 @@ def fast_linear_fit( a = ybar - b * xbar y_fit = np.expand_dims(a, axis=1) + np.expand_dims(b, axis=1) * x - return y_fit.T, [a, b] + return y_fit.T, [a.item(), b.item()] def discrete_uniform_sampler(upper_value): diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 9cb8088..9507965 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -25,7 +25,7 @@ class Node: Attributes ---------- - value : float + value : npt.NDArray[np.float_] idx_data_points : Optional[npt.NDArray[np.int_]] idx_split_variable : int linear_params: Optional[List[float]] = None @@ -35,7 +35,7 @@ class Node: def __init__( self, - value: float = -1.0, + value: npt.NDArray[np.float_] = np.array([-1.0]), idx_data_points: Optional[npt.NDArray[np.int_]] = None, idx_split_variable: int = -1, linear_params: Optional[List[float]] = None, @@ -48,7 +48,7 @@ def __init__( @classmethod def new_leaf_node( cls, - value: float, + value: npt.NDArray[np.float_], idx_data_points: Optional[npt.NDArray[np.int_]] = None, idx_split_variable: int = -1, linear_params: Optional[List[float]] = None, @@ -61,7 +61,7 @@ def new_leaf_node( ) @classmethod - def new_split_node(cls, split_value: float, idx_split_variable: int) -> "Node": + def new_split_node(cls, split_value: npt.NDArray[np.float_], idx_split_variable: int) -> "Node": return cls(value=split_value, idx_split_variable=idx_split_variable) def is_split_node(self) -> bool: @@ -129,7 +129,7 @@ def __init__( @classmethod def new_tree( cls, - leaf_node_value: float, + leaf_node_value: npt.NDArray[np.float_], idx_data_points: Optional[npt.NDArray[np.int_]], num_observations: int, shape: int, @@ -170,7 +170,11 @@ def set_node(self, index: int, node: Node) -> None: self.idx_leaf_nodes.append(index) def grow_leaf_node( - self, current_node: Node, selected_predictor: int, split_value: float, index_leaf_node: int + self, + current_node: Node, + selected_predictor: int, + split_value: npt.NDArray[np.float_], + index_leaf_node: int, ) -> None: current_node.value = split_value current_node.idx_split_variable = selected_predictor @@ -269,7 +273,7 @@ def _traverse_tree( split_variable = current_node.idx_split_variable if excluded is not None and current_node.idx_split_variable in excluded: - leaf_values: List[float] = [] + leaf_values: List[npt.NDArray[np.float_]] = [] self._traverse_leaf_values(leaf_values, node_index) return np.mean(leaf_values, axis=0) @@ -281,13 +285,15 @@ def _traverse_tree( x=x, m=m, node_index=next_node, split_variable=split_variable, excluded=excluded ) - def _traverse_leaf_values(self, leaf_values: List[float], node_index: int) -> None: + def _traverse_leaf_values( + self, leaf_values: List[npt.NDArray[np.float_]], node_index: int + ) -> None: """ Traverse the tree appending leaf values starting from a particular node. Parameters ---------- - leaf_values : List[float] + leaf_values : List[npt.NDArray[np.float_]] node_index : int """ node = self.get_node(node_index) From 701d8e1da4e0677ee36dd3bd289e5c81d6a10c7e Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Mon, 29 May 2023 20:58:19 +0200 Subject: [PATCH 27/29] rm .item() from fast linear mdoel output --- pymc_bart/pgbart.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 3523008..bb1d509 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -515,7 +515,7 @@ def fast_mean(ari: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_ @njit def fast_linear_fit( x: npt.NDArray[np.float_], y: npt.NDArray[np.float_] -) -> Tuple[npt.NDArray[np.float_], List[float]]: +) -> Tuple[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]]: n = len(x) xbar = np.sum(x) / n @@ -534,7 +534,7 @@ def fast_linear_fit( a = ybar - b * xbar y_fit = np.expand_dims(a, axis=1) + np.expand_dims(b, axis=1) * x - return y_fit.T, [a.item(), b.item()] + return y_fit.T, [a, b] def discrete_uniform_sampler(upper_value): From c24262803a5d443145bf423502b7ec2cfa434c1e Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Mon, 29 May 2023 23:01:11 +0200 Subject: [PATCH 28/29] improve missing values filters --- pymc_bart/pgbart.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index bb1d509..d378dab 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -408,8 +408,10 @@ def grow_tree( index_selected_predictor = ssv.rvs() selected_predictor = available_predictors[index_selected_predictor] - available_splitting_values = X[idx_data_points, selected_predictor] - split_value = get_split_value(available_splitting_values, idx_data_points, missing_data) + idx_data_points, available_splitting_values = filter_missing_values( + X[idx_data_points, selected_predictor], idx_data_points, missing_data + ) + split_value = get_split_value(available_splitting_values) if split_value is None: return None @@ -452,13 +454,15 @@ def get_new_idx_data_points(available_splitting_values, split_value, idx_data_po return idx_data_points[split_idx], idx_data_points[~split_idx] -def get_split_value(available_splitting_values, idx_data_points, missing_data): +def filter_missing_values(available_splitting_values, idx_data_points, missing_data): if missing_data: - idx_data_points = idx_data_points[~np.isnan(available_splitting_values)] - available_splitting_values = available_splitting_values[ - ~np.isnan(available_splitting_values) - ] + mask = ~np.isnan(available_splitting_values) + idx_data_points = idx_data_points[mask] + available_splitting_values = available_splitting_values[mask] + return idx_data_points, available_splitting_values + +def get_split_value(available_splitting_values): split_value = None if available_splitting_values.size > 0: idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values)) From c08f86b7be9eafa31dadca24720396c0df368edb Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Mon, 29 May 2023 23:01:38 +0200 Subject: [PATCH 29/29] improve missing values filters --- pymc_bart/pgbart.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index d378dab..dc91c59 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -467,7 +467,6 @@ def get_split_value(available_splitting_values): if available_splitting_values.size > 0: idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values)) split_value = available_splitting_values[idx_selected_splitting_values] - return split_value