diff --git a/.pylintrc b/.pylintrc
index de9c879..322f29a 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -256,6 +256,9 @@ good-names=i,
            rv,
            new_X,
            new_y,
+           a,
+           b,
+           n,
 
 # Include a hint for the correct naming format with invalid-name
diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py
index b7f688f..b10bbc6 100644
--- a/pymc_bart/bart.py
+++ b/pymc_bart/bart.py
@@ -52,7 +52,7 @@ def rng_fn(cls, rng=None, X=None, Y=None, m=None, alpha=None, split_prior=None,
             else:
                 return np.full(cls.Y.shape[0], cls.Y.mean())
         else:
-            return _sample_posterior(cls.all_trees, cls.X, rng=rng).squeeze().T
+            return _sample_posterior(cls.all_trees, cls.X, cls.m, rng=rng).squeeze().T
 
 
 bart = BARTRV()
@@ -72,6 +72,9 @@ class BART(Distribution):
         The response vector.
     m : int
         Number of trees
+    response : str
+        How the leaf node values are computed. Available options are ``constant``, ``linear`` or
+        ``mix``. Defaults to ``constant``.
     alpha : float
         Control the prior probability over the depth of the trees. Even though it can take values
         in the interval (0, 1), it is recommended to be in the interval (0, 0.5].
@@ -88,6 +91,7 @@ def __new__(
         X: TensorLike,
         Y: TensorLike,
         m: int = 50,
         alpha: float = 0.25,
+        response: str = "constant",
         split_prior: Optional[List[float]] = None,
         **kwargs,
     ):
@@ -110,6 +114,7 @@ def __new__(
                 X=X,
                 Y=Y,
                 m=m,
+                response=response,
                 alpha=alpha,
                 split_prior=split_prior,
             ),
diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py
index 81ea8b1..dc91c59 100644
--- a/pymc_bart/pgbart.py
+++ b/pymc_bart/pgbart.py
@@ -54,6 +54,7 @@ def sample_tree(
         missing_data,
         sum_trees,
         m,
+        response,
         normal,
         shape,
     ) -> bool:
@@ -73,6 +74,7 @@ def sample_tree(
                     missing_data,
                     sum_trees,
                     m,
+                    response,
                     normal,
                     self.kfactor,
                     shape,
@@ -131,6 +133,7 @@ def __init__(
         self.missing_data = np.any(np.isnan(self.X))
         self.m = self.bart.m
+        self.response = self.bart.response
 
         shape = initial_values[value_bart.name].shape
         self.shape = 1 if len(shape) == 1 else shape[0]
@@ -160,6 +163,7 @@ def __init__(
             num_observations=self.num_observations,
             shape=self.shape,
         )
+        self.normal = NormalSampler(mu_std, self.shape)
         self.uniform = UniformSampler(0, 1)
         self.uniform_kf = UniformSampler(0.33, 0.75, self.shape)
@@ -209,6 +213,7 @@ def astep(self, _):
                     self.missing_data,
                     self.sum_trees,
                     self.m,
+                    self.response,
                     self.normal,
                     self.shape,
                 ):
@@ -393,6 +398,7 @@ def grow_tree(
     missing_data,
     sum_trees,
     m,
+    response,
     normal,
     kfactor,
     shape,
@@ -402,8 +408,10 @@ def grow_tree(
     index_selected_predictor = ssv.rvs()
     selected_predictor = available_predictors[index_selected_predictor]
-    available_splitting_values = X[idx_data_points, selected_predictor]
-    split_value = get_split_value(available_splitting_values, idx_data_points, missing_data)
+    idx_data_points, available_splitting_values = filter_missing_values(
+        X[idx_data_points, selected_predictor], idx_data_points, missing_data
+    )
+    split_value = get_split_value(available_splitting_values)
 
     if split_value is None:
         return None
@@ -415,18 +423,24 @@ def grow_tree(
         get_idx_right_child(index_leaf_node),
     )
 
+    if response == "mix":
+        response = "linear" if np.random.random() >= 0.5 else "constant"
+
     for idx in range(2):
         idx_data_point = new_idx_data_points[idx]
-        node_value = draw_leaf_value(
-            sum_trees[:, idx_data_point],
-            m,
-            normal.rvs() * kfactor,
-            shape,
+        node_value, linear_params = draw_leaf_value(
+            y_mu_pred=sum_trees[:, idx_data_point],
+            x_mu=X[idx_data_point, selected_predictor],
+            m=m,
+            norm=normal.rvs() * kfactor,
+            shape=shape,
+            response=response,
         )
         new_node = Node.new_leaf_node(
             value=node_value,
             idx_data_points=idx_data_point,
+            linear_params=linear_params,
         )
         tree.set_node(current_node_children[idx], new_node)
@@ -440,39 +454,52 @@ def get_new_idx_data_points(available_splitting_values, split_value, idx_data_po
     return idx_data_points[split_idx], idx_data_points[~split_idx]
 
 
-def get_split_value(available_splitting_values, idx_data_points, missing_data):
+def filter_missing_values(available_splitting_values, idx_data_points, missing_data):
     if missing_data:
-        idx_data_points = idx_data_points[~np.isnan(available_splitting_values)]
-        available_splitting_values = available_splitting_values[
-            ~np.isnan(available_splitting_values)
-        ]
+        mask = ~np.isnan(available_splitting_values)
+        idx_data_points = idx_data_points[mask]
+        available_splitting_values = available_splitting_values[mask]
+    return idx_data_points, available_splitting_values
+
 
+def get_split_value(available_splitting_values):
     split_value = None
     if available_splitting_values.size > 0:
         idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values))
         split_value = available_splitting_values[idx_selected_splitting_values]
-
     return split_value
 
 
-@njit
-def draw_leaf_value(y_mu_pred, m, norm, shape):
+def draw_leaf_value(
+    y_mu_pred: npt.NDArray[np.float_],
+    x_mu: npt.NDArray[np.float_],
+    m: int,
+    norm: npt.NDArray[np.float_],
+    shape: int,
+    response: str,
+) -> Tuple[npt.NDArray[np.float_], Optional[npt.NDArray[np.float_]]]:
     """Draw Gaussian distributed leaf values."""
+    linear_params = None
+    mu_mean = np.empty(shape)
     if y_mu_pred.size == 0:
-        return np.zeros(shape)
+        return np.zeros(shape), linear_params
 
     if y_mu_pred.size == 1:
         mu_mean = np.full(shape, y_mu_pred.item() / m)
     else:
-        mu_mean = fast_mean(y_mu_pred) / m
+        if response == "constant":
+            mu_mean = fast_mean(y_mu_pred) / m
+        if response == "linear":
+            y_fit, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred)
+            mu_mean = y_fit / m
 
-    return norm + mu_mean
+    draw = norm + mu_mean
+    return draw, linear_params
 
 
 @njit
-def fast_mean(ari):
+def fast_mean(ari: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_]]:
     """Use Numba to speed up the computation of the mean."""
-
     if ari.ndim == 1:
         count = ari.shape[0]
         suma = 0
@@ -488,6 +515,31 @@
     return res / count
 
 
+@njit
+def fast_linear_fit(
+    x: npt.NDArray[np.float_], y: npt.NDArray[np.float_]
+) -> Tuple[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]]:
+    n = len(x)
+
+    xbar = np.sum(x) / n
+    ybar = np.sum(y, axis=1) / n
+
+    x_diff = x - xbar
+    y_diff = y - np.expand_dims(ybar, axis=1)
+
+    x_var = np.dot(x_diff, x_diff.T)
+
+    if x_var == 0:
+        b = np.zeros(y.shape[0])
+    else:
+        b = np.dot(x_diff, y_diff.T) / x_var
+
+    a = ybar - b * xbar
+
+    y_fit = np.expand_dims(a, axis=1) + np.expand_dims(b, axis=1) * x
+    return y_fit.T, [a, b]
+
+
 def discrete_uniform_sampler(upper_value):
     """Draw from the uniform distribution with bounds [0, upper_value).
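
Aside (not part of the patch): the new `fast_linear_fit` above is an ordinary
least-squares fit of the leaf responses on a single predictor, computed once per
output dimension. The sketch below reproduces the same algebra in plain NumPy;
the inputs `x` and `y` are made up, with `y` shaped `(shape, n)` to match the
sum-of-trees layout used by `draw_leaf_value`:

    import numpy as np

    x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])       # predictor values at the leaf
    y = np.array([[2.0, 4.0, 6.0, 8.0, 10.0]])    # responses, shape (1, n)

    xbar = x.mean()
    ybar = y.mean(axis=1)
    x_diff = x - xbar
    x_var = x_diff @ x_diff                       # scalar; zero if x is constant
    b = (x_diff @ (y - ybar[:, None]).T) / x_var  # slope per output dimension
    a = ybar - b * xbar                           # intercept per output dimension
    y_fit = a[:, None] + b[:, None] * x           # fitted line, shape (1, n)

Here `a == [0.0]` and `b == [2.0]`. `fast_linear_fit` returns the transposed fit
`y_fit.T` (shape `(n, shape)`) together with `[a, b]`, and guards the
constant-predictor case (`x_var == 0`) by setting the slope to zero.
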
diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py
index 463e01b..9507965 100644
--- a/pymc_bart/tree.py
+++ b/pymc_bart/tree.py
@@ -25,29 +25,43 @@ class Node:
 
     Attributes
     ----------
-    value : float
+    value : npt.NDArray[np.float_]
     idx_data_points : Optional[npt.NDArray[np.int_]]
-    idx_split_variable : Optional[npt.NDArray[np.int_]]
+    idx_split_variable : int
+    linear_params : Optional[List[float]] = None
     """
 
-    __slots__ = "value", "idx_split_variable", "idx_data_points"
+    __slots__ = "value", "idx_split_variable", "idx_data_points", "linear_params"
 
     def __init__(
         self,
-        value: float = -1.0,
+        value: npt.NDArray[np.float_] = np.array([-1.0]),
         idx_data_points: Optional[npt.NDArray[np.int_]] = None,
         idx_split_variable: int = -1,
+        linear_params: Optional[List[float]] = None,
     ) -> None:
         self.value = value
         self.idx_data_points = idx_data_points
         self.idx_split_variable = idx_split_variable
+        self.linear_params = linear_params
 
     @classmethod
-    def new_leaf_node(cls, value: float, idx_data_points: Optional[npt.NDArray[np.int_]]) -> "Node":
-        return cls(value=value, idx_data_points=idx_data_points)
+    def new_leaf_node(
+        cls,
+        value: npt.NDArray[np.float_],
+        idx_data_points: Optional[npt.NDArray[np.int_]] = None,
+        idx_split_variable: int = -1,
+        linear_params: Optional[List[float]] = None,
+    ) -> "Node":
+        return cls(
+            value=value,
+            idx_data_points=idx_data_points,
+            idx_split_variable=idx_split_variable,
+            linear_params=linear_params,
+        )
 
     @classmethod
-    def new_split_node(cls, split_value: float, idx_split_variable: int) -> "Node":
+    def new_split_node(cls, split_value: npt.NDArray[np.float_], idx_split_variable: int) -> "Node":
         return cls(value=split_value, idx_split_variable=idx_split_variable)
 
     def is_split_node(self) -> bool:
@@ -115,7 +129,7 @@ def __init__(
     @classmethod
     def new_tree(
         cls,
-        leaf_node_value: float,
+        leaf_node_value: npt.NDArray[np.float_],
         idx_data_points: Optional[npt.NDArray[np.int_]],
         num_observations: int,
         shape: int,
@@ -136,7 +150,12 @@ def __setitem__(self, index, node) -> None:
 
     def copy(self) -> "Tree":
         tree: Dict[int, Node] = {
-            k: Node(v.value, v.idx_data_points, v.idx_split_variable)
+            k: Node(
+                value=v.value,
+                idx_data_points=v.idx_data_points,
+                idx_split_variable=v.idx_split_variable,
+                linear_params=v.linear_params,
+            )
             for k, v in self.tree_structure.items()
         }
         idx_leaf_nodes = self.idx_leaf_nodes.copy() if self.idx_leaf_nodes is not None else None
@@ -151,7 +170,11 @@ def set_node(self, index: int, node: Node) -> None:
             self.idx_leaf_nodes.append(index)
 
     def grow_leaf_node(
-        self, current_node: Node, selected_predictor: int, split_value: float, index_leaf_node: int
+        self,
+        current_node: Node,
+        selected_predictor: int,
+        split_value: npt.NDArray[np.float_],
+        index_leaf_node: int,
     ) -> None:
         current_node.value = split_value
         current_node.idx_split_variable = selected_predictor
@@ -161,7 +184,13 @@
     def trim(self) -> "Tree":
         tree: Dict[int, Node] = {
-            k: Node(v.value, None, v.idx_split_variable) for k, v in self.tree_structure.items()
+            k: Node(
+                value=v.value,
+                idx_data_points=None,
+                idx_split_variable=v.idx_split_variable,
+                linear_params=v.linear_params,
+            )
+            for k, v in self.tree_structure.items()
         }
         return Tree(tree_structure=tree, idx_leaf_nodes=None, output=np.array([-1]))
@@ -176,11 +205,11 @@ def _predict(self) -> npt.NDArray[np.float_]:
         if self.idx_leaf_nodes is not None:
             for node_index in self.idx_leaf_nodes:
                 leaf_node = self.get_node(node_index)
-                output[leaf_node.idx_data_points] = leaf_node.value
+                output[leaf_node.idx_data_points] = leaf_node.value.squeeze()
         return output.T
 
     def predict(
-        self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = None
+        self, x: npt.NDArray[np.float_], m: int, excluded: Optional[List[int]] = None
     ) -> npt.NDArray[np.float_]:
         """
         Predict output of tree for an (un)observed point x.
@@ -189,6 +218,8 @@
         ----------
         x : npt.NDArray[np.float_]
             Unobserved point
+        m : int
+            Number of trees
         excluded: Optional[List[int]]
             Indexes of the variables to exclude when computing predictions
@@ -199,12 +230,14 @@
         """
         if excluded is None:
             excluded = []
-        return self._traverse_tree(x, 0, excluded)
+        return self._traverse_tree(x=x, m=m, node_index=0, split_variable=-1, excluded=excluded)
 
     def _traverse_tree(
         self,
         x: npt.NDArray[np.float_],
+        m: int,
         node_index: int,
+        split_variable: int = -1,
         excluded: Optional[List[int]] = None,
     ) -> npt.NDArray[np.float_]:
         """
@@ -214,8 +247,12 @@
         ----------
         x : npt.NDArray[np.float_]
             Unobserved point
+        m : int
+            Number of trees
         node_index : int
             Index of the node to start the traversal from
+        split_variable : int
+            Index of the variable used to split the node
         excluded: Optional[List[int]]
             Indexes of the variables to exclude when computing predictions
@@ -224,12 +261,19 @@
         npt.NDArray[np.float_]
             Leaf node value or mean of leaf node values
         """
-        current_node: Node = self.get_node(node_index)
+        current_node = self.get_node(node_index)
         if current_node.is_leaf_node():
-            return np.array(current_node.value)
+            if current_node.linear_params is None:
+                return np.array(current_node.value)
+
+            x = x[split_variable].item()
+            y_x = current_node.linear_params[0] + current_node.linear_params[1] * x
+            return np.array(y_x / m)
+
+        split_variable = current_node.idx_split_variable
 
         if excluded is not None and current_node.idx_split_variable in excluded:
-            leaf_values: List[float] = []
+            leaf_values: List[npt.NDArray[np.float_]] = []
             self._traverse_leaf_values(leaf_values, node_index)
             return np.mean(leaf_values, axis=0)
 
@@ -237,15 +281,19 @@
             next_node = get_idx_left_child(node_index)
         else:
             next_node = get_idx_right_child(node_index)
-        return self._traverse_tree(x=x, node_index=next_node, excluded=excluded)
+        return self._traverse_tree(
+            x=x, m=m, node_index=next_node, split_variable=split_variable, excluded=excluded
+        )
 
-    def _traverse_leaf_values(self, leaf_values: List[float], node_index: int) -> None:
+    def _traverse_leaf_values(
+        self, leaf_values: List[npt.NDArray[np.float_]], node_index: int
+    ) -> None:
         """
         Traverse the tree appending leaf values starting from a particular node.
 
         Parameters
         ----------
-        leaf_values : List[float]
+        leaf_values : List[npt.NDArray[np.float_]]
         node_index : int
         """
         node = self.get_node(node_index)
diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py
index 45b9a75..60e6c7d 100644
--- a/pymc_bart/utils.py
+++ b/pymc_bart/utils.py
@@ -21,6 +21,7 @@ def _sample_posterior(
     all_trees: List[List[Tree]],
     X: TensorLike,
+    m: int,
     rng: np.random.Generator,
     size: Optional[Union[int, Tuple[int, ...]]] = None,
     excluded: Optional[npt.NDArray[np.int_]] = None,
@@ -35,6 +36,8 @@ def _sample_posterior(
     X : tensor-like
         A covariate matrix. Use the same matrix used to fit BART for in-sample predictions, or a
         new one for out-of-sample predictions.
+    m : int
+        Number of trees
     rng : NumPy RandomGenerator
     size : int or tuple
         Number of samples.
@@ -57,13 +60,13 @@ def _sample_posterior(
         flatten_size *= s
 
     idx = rng.integers(0, len(stacked_trees), size=flatten_size)
 
-    shape = stacked_trees[0][0].predict(X[0]).size
+    shape = stacked_trees[0][0].predict(x=X[0], m=m).size
 
     pred = np.zeros((flatten_size, X.shape[0], shape))
 
     for ind, p in enumerate(pred):
         for tree in stacked_trees[idx[ind]]:
-            p += np.vstack([tree.predict(x, excluded) for x in X])
+            p += np.vstack([tree.predict(x=x, m=m, excluded=excluded) for x in X])
     pred.reshape((*size_iter, shape, -1))
     return pred
@@ -220,6 +223,8 @@ def plot_dependence(
     Returns
     -------
     axes: matplotlib axes
     """
+    m: int = bartrv.owner.op.m
+
     if kind not in ["pdp", "ice"]:
         raise ValueError(f"kind={kind} is not supported. Available options are 'pdp' or 'ice'")
@@ -294,7 +299,7 @@ def plot_dependence(
                 new_X[:, indices_mi] = X[:, indices_mi]
                 new_X[:, i] = x_i
                 y_pred.append(
-                    np.mean(_sample_posterior(all_trees, X=new_X, rng=rng, size=samples), 1)
+                    np.mean(_sample_posterior(all_trees, X=new_X, m=m, rng=rng, size=samples), 1)
                 )
                 new_x_target.append(new_x_i)
         else:
@@ -302,7 +307,7 @@ def plot_dependence(
             new_X = X[idx_s]
             new_X[:, indices_mi] = X[:, indices_mi][instance]
             y_pred.append(
-                np.mean(_sample_posterior(all_trees, X=new_X, rng=rng, size=samples), 0)
+                np.mean(_sample_posterior(all_trees, X=new_X, m=m, rng=rng, size=samples), 0)
             )
             new_x_target.append(new_X[:, i])
         y_mins.append(np.min(y_pred))
@@ -445,6 +450,8 @@ def plot_variable_importance(
     """
     _, axes = plt.subplots(2, 1, figsize=figsize)
 
+    m: int = bartrv.owner.op.m
+
     if hasattr(X, "columns") and hasattr(X, "values"):
         labels = X.columns
         X = X.values
@@ -474,13 +481,13 @@ def plot_variable_importance(
 
     all_trees = bartrv.owner.op.all_trees
 
-    predicted_all = _sample_posterior(all_trees, X=X, rng=rng, size=samples, excluded=None)
+    predicted_all = _sample_posterior(all_trees, X=X, m=m, rng=rng, size=samples, excluded=None)
 
     ev_mean = np.zeros(len(var_imp))
     ev_hdi = np.zeros((len(var_imp), 2))
     for idx, subset in enumerate(subsets):
         predicted_subset = _sample_posterior(
-            all_trees=all_trees, X=X, rng=rng, size=samples, excluded=subset
+            all_trees=all_trees, X=X, m=m, rng=rng, size=samples, excluded=subset
         )
         pearson = np.zeros(samples)
         for j in range(samples):
diff --git a/tests/test_bart.py b/tests/test_bart.py
index d8a3a27..3da9a2f 100644
--- a/tests/test_bart.py
+++ b/tests/test_bart.py
@@ -40,13 +40,18 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True):
     assert np.isfinite(logp_moment)
 
 
-def test_bart_vi():
+@pytest.mark.parametrize(
+    argnames="response",
+    argvalues=["constant", "linear"],
+    ids=["constant", "linear-response"],
+)
+def test_bart_vi(response):
     X = np.random.normal(0, 1, size=(250, 3))
     Y = np.random.normal(0, 1, size=250)
     X[:, 0] = np.random.normal(Y, 0.1)
 
     with pm.Model() as model:
-        mu = pmb.BART("mu", X, Y, m=10)
+        mu = pmb.BART("mu", X, Y, m=10, response=response)
         sigma = pm.HalfNormal("sigma", 1)
         y = pm.Normal("y", mu, sigma, observed=Y)
         idata = pm.sample(random_seed=3415)
@@ -60,25 +65,35 @@ def test_bart_vi():
     assert_almost_equal(var_imp.sum(), 1)
 
 
-def test_missing_data():
+@pytest.mark.parametrize(
+    argnames="response",
+    argvalues=["constant", "linear"],
+    ids=["constant", "linear-response"],
+)
+def test_missing_data(response):
     X = np.random.normal(0, 1, size=(50, 2))
     Y = np.random.normal(0, 1, size=50)
     X[10:20, 0] = np.nan
 
     with pm.Model() as model:
-        mu = pmb.BART("mu", X, Y, m=10)
+        mu = pmb.BART("mu", X, Y, m=10, response=response)
         sigma = pm.HalfNormal("sigma", 1)
         y = pm.Normal("y", mu, sigma, observed=Y)
         idata = pm.sample(tune=100, draws=100, chains=1, random_seed=3415)
 
 
-def test_shared_variable():
+@pytest.mark.parametrize(
+    argnames="response",
+    argvalues=["constant", "linear"],
+    ids=["constant", "linear-response"],
+)
+def test_shared_variable(response):
     X = np.random.normal(0, 1, size=(50, 2))
     Y = np.random.normal(0, 1, size=50)
 
     with pm.Model() as model:
         data_X = pm.MutableData("data_X", X)
-        mu = pmb.BART("mu", data_X, Y, m=2)
+        mu = pmb.BART("mu", data_X, Y, m=2, response=response)
         sigma = pm.HalfNormal("sigma", 1)
         y = pm.Normal("y", mu, sigma, observed=Y, shape=mu.shape)
         idata = pm.sample(tune=100, draws=100, chains=2, random_seed=3415)
@@ -90,12 +105,17 @@ def test_shared_variable():
     assert ppc2.posterior_predictive["y"].shape == (2, 100, 3)
 
 
-def test_shape():
+@pytest.mark.parametrize(
+    argnames="response",
+    argvalues=["constant", "linear"],
+    ids=["constant", "linear-response"],
+)
+def test_shape(response):
     X = np.random.normal(0, 1, size=(250, 3))
     Y = np.random.normal(0, 1, size=250)
 
     with pm.Model() as model:
-        w = pmb.BART("w", X, Y, m=2, shape=(2, 250))
+        w = pmb.BART("w", X, Y, m=2, response=response, shape=(2, 250))
         y = pm.Normal("y", w[0], pm.math.abs(w[1]), observed=Y)
         idata = pm.sample(random_seed=3415)
@@ -119,9 +139,13 @@ class TestUtils:
     def test_sample_posterior(self):
         all_trees = self.mu.owner.op.all_trees
         rng = np.random.default_rng(3)
-        pred_all = pmb.utils._sample_posterior(all_trees, X=self.X, rng=rng, size=2)
+        pred_all = pmb.utils._sample_posterior(
+            all_trees, X=self.X, m=self.mu.owner.op.m, rng=rng, size=2
+        )
         rng = np.random.default_rng(3)
-        pred_first = pmb.utils._sample_posterior(all_trees, X=self.X[:10], rng=rng)
+        pred_first = pmb.utils._sample_posterior(
+            all_trees, X=self.X[:10], m=self.mu.owner.op.m, rng=rng
+        )
 
         assert_almost_equal(pred_first[0], pred_all[0, :10], decimal=4)
         assert pred_all.shape == (2, 50, 1)
diff --git a/tests/test_pgbart.py b/tests/test_pgbart.py
index ec9fbac..841b8be 100644
--- a/tests/test_pgbart.py
+++ b/tests/test_pgbart.py
@@ -1,11 +1,16 @@
 from unittest import TestCase
-
+import pytest
 import numpy as np
 import pymc as pm
 
 import pymc_bart as pmb
-from pymc_bart.pgbart import (NormalSampler, UniformSampler,
-                              discrete_uniform_sampler, fast_mean)
+from pymc_bart.pgbart import (
+    NormalSampler,
+    UniformSampler,
+    discrete_uniform_sampler,
+    fast_mean,
+    fast_linear_fit,
+)
 
 
 class TestSystematic(TestCase):
@@ -24,7 +29,7 @@ def test_systematic(self):
         indices = step.systematic(normalized_weights)
 
         self.assertEqual(len(indices), len(normalized_weights))
-        self.assertEqual(indices.dtype, np.int)
+        self.assertEqual(indices.dtype, np.int_)
         self.assertTrue(all(i >= 0 and i < len(normalized_weights) for i in indices))
 
         normalized_weights = np.array([0, 0.25, 0.75])
@@ -40,6 +45,23 @@ def test_fast_mean():
     np.testing.assert_array_almost_equal(fast_mean(values), np.mean(values, 1))
 
 
+@pytest.mark.parametrize(
+    argnames="x,y,a_expected,b_expected",
+    argvalues=[
+        (np.array([1, 2, 3, 4, 5]), np.array([[1, 2, 3, 4, 5]]), 0.0, 1.0),
+        (np.array([1, 2, 3, 4, 5]), np.array([[1, 1, 1, 1, 1]]), 1.0, 0.0),
+    ],
+    ids=["1d-id", "1d-const"],
+)
+def test_fast_linear_fit(x, y, a_expected, b_expected):
+    y_fit, linear_params = fast_linear_fit(x, y)
+    assert linear_params[0] == a_expected
+    assert linear_params[1] == b_expected
+    np.testing.assert_almost_equal(
+        actual=y_fit, desired=np.atleast_2d(a_expected + x * b_expected).T
+    )
+
+
 def test_discrete_uniform():
     sample = discrete_uniform_sampler(7)
     assert isinstance(sample, int)
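
Usage note (not part of the patch): end to end, the new `response` argument is
exercised exactly as in the parametrized tests above. A minimal sketch with
synthetic data:

    import numpy as np
    import pymc as pm
    import pymc_bart as pmb

    X = np.random.normal(0, 1, size=(100, 2))
    Y = X[:, 0] + np.random.normal(0, 0.1, size=100)

    with pm.Model():
        # "linear" fits a line within each leaf; "mix" picks "constant" or
        # "linear" at random each time a node is grown. The default,
        # "constant", keeps the previous behavior.
        mu = pmb.BART("mu", X, Y, m=20, response="linear")
        sigma = pm.HalfNormal("sigma", 1)
        y = pm.Normal("y", mu, sigma, observed=Y)
        idata = pm.sample(random_seed=0)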