diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index 2b3a85e..e8b514f 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -40,7 +40,7 @@ class BARTRV(RandomVariable): ndims_params: List[int] = [2, 1, 0, 0, 0, 1] dtype: str = "floatX" _print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}") - all_trees = List[List[Tree]] + all_trees = List[List[List[Tree]]] def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): return dist_params[0].shape[:1] @@ -96,6 +96,15 @@ class BART(Distribution): List of SplitRule objects, one per column in input data. Allows using different split rules for different columns. Default is ContinuousSplitRule. Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables. + shape: : Optional[Tuple], default None + Specify the output shape. If shape is different from (len(X)) (the default), train a + separate tree for each value in other dimensions. + separate_trees : Optional[bool], default False + When training multiple trees (by setting a shape parameter), the default behavior is to + learn a joint tree structure and only have different leaf values for each. + This flag forces a fully separate tree structure to be trained instead. + This is unnecessary in many cases and is considerably slower, multiplying + run-time roughly by number of dimensions. Notes ----- @@ -115,6 +124,7 @@ def __new__( response: str = "constant", split_prior: Optional[List[float]] = None, split_rules: Optional[SplitRule] = None, + separate_trees: Optional[bool] = False, **kwargs, ): manager = Manager() @@ -141,6 +151,7 @@ def __new__( beta=beta, split_prior=split_prior, split_rules=split_rules, + separate_trees=separate_trees, ), )() diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index b49900e..23c71ef 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -134,9 +134,17 @@ def __init__( self.missing_data = np.any(np.isnan(self.X)) self.m = self.bart.m self.response = self.bart.response + shape = initial_values[value_bart.name].shape + self.shape = 1 if len(shape) == 1 else shape[0] + # Set trees_shape (dim for separate tree structures) + # and leaves_shape (dim for leaf node values) + # One of the two is always one, the other equal to self.shape + self.trees_shape = self.shape if self.bart.separate_trees else 1 + self.leaves_shape = self.shape if not self.bart.separate_trees else 1 + if self.bart.split_prior: self.alpha_vec = self.bart.split_prior else: @@ -153,27 +161,31 @@ def __init__( self.available_predictors = list(range(self.num_variates)) # if data is binary + self.leaf_sd = np.ones((self.trees_shape, self.leaves_shape)) + y_unique = np.unique(self.bart.Y) if y_unique.size == 2 and np.all(y_unique == [0, 1]): - self.leaf_sd = 3 / self.m**0.5 + self.leaf_sd *= 3 / self.m**0.5 else: - self.leaf_sd = self.bart.Y.std() / self.m**0.5 + self.leaf_sd *= self.bart.Y.std() / self.m**0.5 - self.running_sd = RunningSd(shape) + self.running_sd = [ + RunningSd((self.leaves_shape, self.num_observations)) for _ in range(self.trees_shape) + ] - self.sum_trees = np.full((self.shape, self.bart.Y.shape[0]), init_mean).astype( - config.floatX - ) + self.sum_trees = np.full( + (self.trees_shape, self.leaves_shape, self.bart.Y.shape[0]), init_mean + ).astype(config.floatX) self.sum_trees_noi = self.sum_trees - init_mean self.a_tree = Tree.new_tree( leaf_node_value=init_mean / self.m, idx_data_points=np.arange(self.num_observations, dtype="int32"), num_observations=self.num_observations, - shape=self.shape, + shape=self.leaves_shape, split_rules=self.split_rules, ) - self.normal = NormalSampler(1, self.shape) + self.normal = NormalSampler(1, self.leaves_shape) self.uniform = UniformSampler(0, 1) self.prior_prob_leaf_node = compute_prior_probability(self.bart.alpha, self.bart.beta) self.ssv = SampleSplittingVariable(self.alpha_vec) @@ -188,8 +200,10 @@ def __init__( self.indices = list(range(1, num_particles)) shared = make_shared_replacements(initial_values, vars, model) self.likelihood_logp = logp(initial_values, [model.datalogp], vars, shared) - self.all_particles = [ParticleTree(self.a_tree) for _ in range(self.m)] - self.all_trees = np.array([p.tree for p in self.all_particles]) + self.all_particles = [ + [ParticleTree(self.a_tree) for _ in range(self.m)] for _ in range(self.trees_shape) + ] + self.all_trees = np.array([[p.tree for p in pl] for pl in self.all_particles]) self.lower = 0 self.iter = 0 super().__init__(vars, shared) @@ -201,72 +215,75 @@ def astep(self, _): tree_ids = range(self.lower, upper) self.lower = upper if upper < self.m else 0 - for tree_id in tree_ids: - self.iter += 1 - # Compute the sum of trees without the old tree that we are attempting to replace - self.sum_trees_noi = self.sum_trees - self.all_particles[tree_id].tree._predict() - # Generate an initial set of particles - # at the end we return one of these particles as the new tree - particles = self.init_particles(tree_id) - - while True: - # Sample each particle (try to grow each tree), except for the first one - stop_growing = True - for p in particles[1:]: - if p.sample_tree( - self.ssv, - self.available_predictors, - self.prior_prob_leaf_node, - self.X, - self.missing_data, - self.sum_trees, - self.leaf_sd, - self.m, - self.response, - self.normal, - self.shape, - ): - self.update_weight(p) - if p.expansion_nodes: - stop_growing = False - if stop_growing: - break - - # Normalize weights - normalized_weights = self.normalize(particles[1:]) - - # Resample - particles = self.resample(particles, normalized_weights) - - normalized_weights = self.normalize(particles) - # Get the new particle and associated tree - self.all_particles[tree_id], new_tree = self.get_particle_tree( - particles, normalized_weights - ) - # Update the sum of trees - new = new_tree._predict() - self.sum_trees = self.sum_trees_noi + new - # To reduce memory usage, we trim the tree - self.all_trees[tree_id] = new_tree.trim() - - if self.tune: - # Update the splitting variable and the splitting variable sampler - if self.iter > self.m: - self.ssv = SampleSplittingVariable(self.alpha_vec) - - for index in new_tree.get_split_variables(): - self.alpha_vec[index] += 1 - - # update standard deviation at leaf nodes - if self.iter > 2: - self.leaf_sd = self.running_sd.update(new) - else: - self.running_sd.update(new) + for odim in range(self.trees_shape): + for tree_id in tree_ids: + self.iter += 1 + # Compute the sum of trees without the old tree that we are attempting to replace + self.sum_trees_noi[odim] = ( + self.sum_trees[odim] - self.all_particles[odim][tree_id].tree._predict() + ) + # Generate an initial set of particles + # at the end we return one of these particles as the new tree + particles = self.init_particles(tree_id, odim) + + while True: + # Sample each particle (try to grow each tree), except for the first one + stop_growing = True + for p in particles[1:]: + if p.sample_tree( + self.ssv, + self.available_predictors, + self.prior_prob_leaf_node, + self.X, + self.missing_data, + self.sum_trees[odim], + self.leaf_sd[odim], + self.m, + self.response, + self.normal, + self.leaves_shape, + ): + self.update_weight(p, odim) + if p.expansion_nodes: + stop_growing = False + if stop_growing: + break + + # Normalize weights + normalized_weights = self.normalize(particles[1:]) + + # Resample + particles = self.resample(particles, normalized_weights) + + normalized_weights = self.normalize(particles) + # Get the new particle and associated tree + self.all_particles[odim][tree_id], new_tree = self.get_particle_tree( + particles, normalized_weights + ) + # Update the sum of trees + new = new_tree._predict() + self.sum_trees[odim] = self.sum_trees_noi[odim] + new + # To reduce memory usage, we trim the tree + self.all_trees[odim][tree_id] = new_tree.trim() + + if self.tune: + # Update the splitting variable and the splitting variable sampler + if self.iter > self.m: + self.ssv = SampleSplittingVariable(self.alpha_vec) + + for index in new_tree.get_split_variables(): + self.alpha_vec[index] += 1 + + # update standard deviation at leaf nodes + if self.iter > 2: + self.leaf_sd[odim] = self.running_sd[odim].update(new) + else: + self.running_sd[odim].update(new) - else: - # update the variable inclusion - for index in new_tree.get_split_variables(): - variable_inclusion[index] += 1 + else: + # update the variable inclusion + for index in new_tree.get_split_variables(): + variable_inclusion[index] += 1 if not self.tune: self.bart.all_trees.append(self.all_trees) @@ -331,23 +348,27 @@ def systematic(self, normalized_weights: npt.NDArray[np.float_]) -> npt.NDArray[ single_uniform = (self.uniform.rvs() + np.arange(lnw)) / lnw return inverse_cdf(single_uniform, normalized_weights) - def init_particles(self, tree_id: int) -> List[ParticleTree]: + def init_particles(self, tree_id: int, odim: int) -> List[ParticleTree]: """Initialize particles.""" - p0: ParticleTree = self.all_particles[tree_id] + p0: ParticleTree = self.all_particles[odim][tree_id] # The old tree does not grow so we update the weight only once - self.update_weight(p0) + self.update_weight(p0, odim) particles: List[ParticleTree] = [p0] particles.extend(ParticleTree(self.a_tree) for _ in self.indices) return particles - def update_weight(self, particle: ParticleTree) -> None: + def update_weight(self, particle: ParticleTree, odim: int) -> None: """ Update the weight of a particle. """ - new_likelihood = self.likelihood_logp( - (self.sum_trees_noi + particle.tree._predict()).flatten() + + delta = ( + np.identity(self.trees_shape)[odim][:, None, None] + * particle.tree._predict()[None, :, :] ) + + new_likelihood = self.likelihood_logp((self.sum_trees_noi + delta).flatten()) particle.log_weight = new_likelihood @staticmethod diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 07d243a..4833100 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -147,7 +147,7 @@ def new_tree( ) }, idx_leaf_nodes=[0], - output=np.zeros((num_observations, shape)).astype(config.floatX).squeeze(), + output=np.zeros((num_observations, shape)).astype(config.floatX), split_rules=split_rules, ) @@ -226,7 +226,7 @@ def _predict(self) -> npt.NDArray[np.float_]: if self.idx_leaf_nodes is not None: for node_index in self.idx_leaf_nodes: leaf_node = self.get_node(node_index) - output[leaf_node.idx_data_points] = leaf_node.value.squeeze() + output[leaf_node.idx_data_points] = leaf_node.value return output.T def predict( diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index d6a16eb..250fc51 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -44,11 +44,12 @@ def _sample_posterior( Indexes of the variables to exclude when computing predictions """ stacked_trees = all_trees + if isinstance(X, Variable): X = X.eval() if size is None: - size_iter: Union[List, Tuple] = () + size_iter: Union[List, Tuple] = (1,) elif isinstance(size, int): size_iter = [size] else: @@ -60,13 +61,18 @@ def _sample_posterior( idx = rng.integers(0, len(stacked_trees), size=flatten_size) - pred = np.zeros((flatten_size, X.shape[0], shape)) + trees_shape = len(stacked_trees[0]) + leaves_shape = shape // trees_shape + + pred = np.zeros((flatten_size, trees_shape, leaves_shape, X.shape[0])) for ind, p in enumerate(pred): - for tree in stacked_trees[idx[ind]]: - p += tree.predict(x=X, excluded=excluded, shape=shape).T - pred.reshape((*size_iter, shape, -1)) - return pred + for odim, odim_trees in enumerate(stacked_trees[idx[ind]]): + for tree in odim_trees: + p[odim] += tree.predict(x=X, excluded=excluded, shape=leaves_shape) + + # pred.reshape((*size_iter, shape, -1)) + return pred.transpose((0, 3, 1, 2)).reshape((*size_iter, -1, shape)) def plot_convergence( diff --git a/tests/test_bart.py b/tests/test_bart.py index 69815bf..43496ee 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -217,3 +217,30 @@ def test_bart_moment(size, expected): with pm.Model() as model: pmb.BART("x", X=X, Y=Y, size=size) assert_moment_is_expected(model, expected) + + +@pytest.mark.parametrize( + argnames="separate_trees,split_rule", + argvalues=[ + (False,pmb.ContinuousSplitRule), + (False,pmb.OneHotSplitRule), + (False,pmb.SubsetSplitRule), + (True,pmb.ContinuousSplitRule) + ], + ids=["continuous", "one-hot", "subset", "separate-trees"], +) +def test_categorical_model(separate_trees,split_rule): + + Y = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]) + X = np.concatenate([Y[:, None], np.random.randint(0, 6, size=(9, 4))], axis=1) + + with pm.Model() as model: + lo = pmb.BART("logodds", X, Y, m=2, shape=(3, 9), + split_rules=[split_rule]*5, + separate_trees=separate_trees) + y = pm.Categorical("y", p=pm.math.softmax(lo.T, axis=-1), observed=Y) + idata = pm.sample(random_seed=3415, tune=300, draws=300) + idata = pm.sample_posterior_predictive(idata, predictions=True, extend_inferencedata=True) + + # Fit should be good enough so right category is selected over 50% of time + assert (idata.predictions.y.median(["chain", "draw"]) == Y).all()