From 0167c88b74e744a2ae13b83901601e2cdc994a90 Mon Sep 17 00:00:00 2001
From: aloctavodia
Date: Tue, 30 Nov 2021 09:50:45 +0200
Subject: [PATCH 1/3] improves sampling by redrawing leaves and increasing
 particles

---
 pymc/bart/pgbart.py | 395 ++++++++++++++++++++++----------------------
 pymc/bart/tree.py   |  46 +-----
 2 files changed, 200 insertions(+), 241 deletions(-)

diff --git a/pymc/bart/pgbart.py b/pymc/bart/pgbart.py
index 3bc7ac0a25..99b6f5fac7 100644
--- a/pymc/bart/pgbart.py
+++ b/pymc/bart/pgbart.py
@@ -15,7 +15,6 @@
 import logging
 
 from copy import copy
-from typing import Any, Dict, List, Tuple
 
 import aesara
 import numpy as np
@@ -25,70 +24,12 @@
 from pymc.aesaraf import inputvars, join_nonshared_inputs, make_shared_replacements
 from pymc.bart.bart import BARTRV
 from pymc.bart.tree import LeafNode, SplitNode, Tree
-from pymc.blocking import RaveledVars
 from pymc.model import modelcontext
 from pymc.step_methods.arraystep import ArrayStepShared, Competence
 
 _log = logging.getLogger("pymc")
 
 
-class ParticleTree:
-    """
-    Particle tree
-    """
-
-    def __init__(self, tree, log_weight, likelihood):
-        self.tree = tree.copy()  # keeps the tree that we care at the moment
-        self.expansion_nodes = [0]
-        self.log_weight = log_weight
-        self.old_likelihood_logp = likelihood
-        self.used_variates = []
-
-    def sample_tree_sequential(
-        self,
-        ssv,
-        available_predictors,
-        prior_prob_leaf_node,
-        X,
-        missing_data,
-        sum_trees_output,
-        mean,
-        linear_fit,
-        m,
-        normal,
-        mu_std,
-        response,
-    ):
-        tree_grew = False
-        if self.expansion_nodes:
-            index_leaf_node = self.expansion_nodes.pop(0)
-            # Probability that this node will remain a leaf node
-            prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
-
-            if prob_leaf < np.random.random():
-                tree_grew, index_selected_predictor = grow_tree(
-                    self.tree,
-                    index_leaf_node,
-                    ssv,
-                    available_predictors,
-                    X,
-                    missing_data,
-                    sum_trees_output,
-                    mean,
-                    linear_fit,
-                    m,
-                    normal,
-                    mu_std,
-                    response,
-                )
-            if tree_grew:
-                new_indexes = self.tree.idx_leaf_nodes[-2:]
-                self.expansion_nodes.extend(new_indexes)
-                self.used_variates.append(index_selected_predictor)
-
-        return tree_grew
-
-
 class PGBART(ArrayStepShared):
     """
     Particle Gibbs BART sampling step
@@ -98,26 +39,15 @@ class PGBART(ArrayStepShared):
     vars: list
         List of value variables for sampler
     num_particles : int
-        Number of particles for the conditional SMC sampler. Defaults to 10
+        Number of particles for the conditional SMC sampler. Defaults to 40
     max_stages : int
         Maximum number of iterations of the conditional SMC sampler. Defaults to 100.
     batch : int or tuple
         Number of trees fitted per step. Defaults to "auto", which is 10% of the `m` trees
-        during tuning and 20% after tuning. If a tuple is passed the first element is the batch size
+        during tuning and after tuning. If a tuple is passed, the first element is the batch size
         during tuning and the second the batch size after tuning.
     model: PyMC Model
         Optional model for sampling step. Defaults to None (taken from context).
-
-    Note
-    ----
-    This sampler is inspired by the [Lakshminarayanan2015] Particle Gibbs sampler, but introduces
-    several changes. The changes will be properly documented soon.
-
-    References
-    ----------
-    .. [Lakshminarayanan2015] Lakshminarayanan, B. and Roy, D.M. and Teh, Y. W., (2015),
-        Particle Gibbs for Bayesian Additive Regression Trees.
- ArviX, `link `__ """ name = "bartsampler" @@ -125,7 +55,7 @@ class PGBART(ArrayStepShared): generates_stats = True stats_dtypes = [{"variable_inclusion": np.ndarray, "bart_trees": np.ndarray}] - def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", model=None): + def __init__(self, vars=None, num_particles=40, max_stages=100, batch="auto", model=None): _log.warning("BART is experimental. Use with caution.") model = modelcontext(model) initial_values = model.recompute_initial_point() @@ -156,9 +86,8 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo self.num_variates = self.X.shape[1] self.available_predictors = list(range(self.num_variates)) - sum_trees_output = np.full_like(self.Y, self.init_mean).astype(aesara.config.floatX) + self.sum_trees = np.full_like(self.Y, self.init_mean).astype(aesara.config.floatX) self.a_tree = Tree.init_tree( - tree_id=0, leaf_node_value=self.init_mean / self.m, idx_data_points=np.arange(self.num_observations, dtype="int32"), m=self.m, @@ -173,7 +102,8 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo self.tune = True if batch == "auto": - self.batch = (max(1, int(self.m * 0.1)), max(1, int(self.m * 0.2))) + batch = max(1, int(self.m * 0.1)) + self.batch = (batch, batch) else: if isinstance(batch, (tuple, list)): self.batch = batch @@ -181,54 +111,56 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo self.batch = (batch, batch) self.log_num_particles = np.log(num_particles) - self.indices = list(range(1, num_particles)) + self.indices = list(range(2, num_particles)) self.len_indices = len(self.indices) self.max_stages = max_stages shared = make_shared_replacements(initial_values, vars, model) self.likelihood_logp = logp(initial_values, [model.datalogpt], vars, shared) - self.init_likelihood = self.likelihood_logp(sum_trees_output) - self.init_log_weight = self.init_likelihood - self.log_num_particles self.all_particles = [] for i in range(self.m): - self.a_tree.tree_id = i - self.a_tree.leaf_node_value = ( - self.init_mean / self.m + self.normal.random() * self.mu_std, - ) - p = ParticleTree( - self.a_tree, - self.init_log_weight, - self.init_likelihood, - ) + self.a_tree.leaf_node_value = self.init_mean / self.m + p = ParticleTree(self.a_tree) self.all_particles.append(p) self.all_trees = np.array([p.tree for p in self.all_particles]) super().__init__(vars, shared) - def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]: - point_map_info = q.point_map_info - sum_trees_output = q.data + def astep(self, _): variable_inclusion = np.zeros(self.num_variates, dtype="int") - tree_ids = np.random.randint(0, self.m, size=self.batch[~self.tune]) + tree_ids = np.random.choice(range(self.m), replace=False, size=self.batch[~self.tune]) for tree_id in tree_ids: # Generate an initial set of SMC particles # at the end of the algorithm we return one of these particles as the new tree particles = self.init_particles(tree_id) - # Compute the sum of trees without the tree we are attempting to replace - self.sum_trees_output_noi = sum_trees_output - particles[0].tree.predict_output() - - # The old tree is not growing so we update the weights only once. - self.update_weight(particles[0], new=True) - for t in range(self.max_stages): - # Sample each particle (try to grow each tree), except for the first one. 
-                for p in particles[1:]:
-                    tree_grew = p.sample_tree_sequential(
+            # Compute the sum of trees without the old tree that we are attempting to replace
+            self.sum_trees_noi = self.sum_trees - particles[0].tree.predict_output()
+            # Resample leaf values for particle 1, which is a copy of the old tree
+            particles[1].sample_leafs(
+                self.sum_trees,
+                self.X,
+                self.mean,
+                self.linear_fit,
+                self.m,
+                self.normal,
+                self.mu_std,
+                self.response,
+            )
+
+            # The old tree and the one with new leaves do not grow, so we update the weights only once
+            self.update_weight(particles[0], old=True)
+            self.update_weight(particles[1], old=True)
+            for _ in range(self.max_stages):
+                # Sample each particle (try to grow each tree), except for the first two
+                stop_growing = True
+                for p in particles[2:]:
+                    tree_grew = p.sample_tree(
                         self.ssv,
                         self.available_predictors,
                         self.prior_prob_leaf_node,
                         self.X,
                         self.missing_data,
-                        sum_trees_output,
+                        self.sum_trees,
                         self.mean,
                         self.linear_fit,
                         self.m,
@@ -238,27 +170,24 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
                         self.normal,
                         self.mu_std,
                         self.response,
                     )
                     if tree_grew:
                         self.update_weight(p)
+                    if p.expansion_nodes:
+                        stop_growing = False
+                if stop_growing:
+                    break
 
             # Normalize weights
-            W_t, normalized_weights = self.normalize(particles[1:])
+            W_t, normalized_weights = self.normalize(particles[2:])
 
-            # Resample all but first particle
-            re_n_w = normalized_weights
-            new_indices = np.random.choice(self.indices, size=self.len_indices, p=re_n_w)
-            particles[1:] = particles[new_indices]
+            # Resample all but first two particles
+            new_indices = np.random.choice(
+                self.indices, size=self.len_indices, p=normalized_weights
+            )
+            particles[2:] = particles[new_indices]
 
             # Set the new weights
-            for p in particles[1:]:
+            for p in particles[2:]:
                 p.log_weight = W_t
 
-            # Check if particles can keep growing, otherwise stop iterating
-            non_available_nodes_for_expansion = []
-            for p in particles[1:]:
-                if p.expansion_nodes:
-                    non_available_nodes_for_expansion.append(0)
-            if all(non_available_nodes_for_expansion):
-                break
-
-            for p in particles[1:]:
+            for p in particles[2:]:
                 p.log_weight = p.old_likelihood_logp
 
             _, normalized_weights = self.normalize(particles)
@@ -268,7 +197,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
             self.all_trees[tree_id] = new_tree
             new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
             self.all_particles[tree_id] = new_particle
-            sum_trees_output = self.sum_trees_output_noi + new_tree.predict_output()
+            self.sum_trees = self.sum_trees_noi + new_tree.predict_output()
 
             if self.tune:
                 self.ssv = SampleSplittingVariable(self.alpha_vec)
@@ -278,21 +207,10 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
             for index in new_particle.used_variates:
                 variable_inclusion[index] += 1
 
-        stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
-        sum_trees_output = RaveledVars(sum_trees_output, point_map_info)
-        return sum_trees_output, [stats]
+        stats = {"variable_inclusion": variable_inclusion, "bart_trees": self.all_trees}
+        return self.sum_trees, [stats]
 
-    @staticmethod
-    def competence(var, has_grad):
-        """
-        PGBART is only suitable for BART distributions
-        """
-        dist = getattr(var.owner, "op", None)
-        if isinstance(dist, BARTRV):
-            return Competence.IDEAL
-        return Competence.INCOMPATIBLE
-
-    def normalize(self, particles: List[ParticleTree]) -> Tuple[float, np.ndarray]:
+    def normalize(self, particles):
        """
        Use logsumexp trick to get W_t and softmax to get
        normalized_weights
        """
 
@@ -313,38 +231,101 @@ def init_particles(self, tree_id: int) -> np.ndarray:
         Initialize particles
         """
         p = self.all_particles[tree_id]
-        p.log_weight = self.init_log_weight
-        p.old_likelihood_logp = self.init_likelihood
         particles = [p]
+        particles.append(copy(p))
 
         for _ in self.indices:
-            self.a_tree.tree_id = tree_id
-            particles.append(
-                ParticleTree(
-                    self.a_tree,
-                    self.init_log_weight,
-                    self.init_likelihood,
-                )
-            )
+            particles.append(ParticleTree(self.a_tree))
 
         return np.array(particles)
 
-    def update_weight(self, particle: List[ParticleTree], new=False) -> None:
+    def update_weight(self, particle, old=False):
         """
         Update the weight of a particle
 
         Since the prior is used as the proposal, the weights are updated additively
         as the difference between the new and old log-likelihoods.
         """
-        new_likelihood = self.likelihood_logp(
-            self.sum_trees_output_noi + particle.tree.predict_output()
-        )
-        if new:
+        new_likelihood = self.likelihood_logp(self.sum_trees_noi + particle.tree.predict_output())
+        if old:
             particle.log_weight = new_likelihood
+            particle.old_likelihood_logp = new_likelihood
         else:
             particle.log_weight += new_likelihood - particle.old_likelihood_logp
             particle.old_likelihood_logp = new_likelihood
 
+    @staticmethod
+    def competence(var, has_grad):
+        """
+        PGBART is only suitable for BART distributions
+        """
+        dist = getattr(var.owner, "op", None)
+        if isinstance(dist, BARTRV):
+            return Competence.IDEAL
+        return Competence.INCOMPATIBLE
+
+
+class ParticleTree:
+    """
+    Particle tree
+    """
+
+    def __init__(self, tree):
+        self.tree = tree.copy()  # keeps the tree we care about at the moment
+        self.expansion_nodes = [0]
+        self.log_weight = 0
+        self.old_likelihood_logp = 0
+        self.used_variates = []
+
+    def sample_tree(
+        self,
+        ssv,
+        available_predictors,
+        prior_prob_leaf_node,
+        X,
+        missing_data,
+        sum_trees,
+        mean,
+        linear_fit,
+        m,
+        normal,
+        mu_std,
+        response,
+    ):
+        tree_grew = False
+        if self.expansion_nodes:
+            index_leaf_node = self.expansion_nodes.pop(0)
+            # Probability that this node will remain a leaf node
+            prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
+
+            if prob_leaf < np.random.random():
+                index_selected_predictor = grow_tree(
+                    self.tree,
+                    index_leaf_node,
+                    ssv,
+                    available_predictors,
+                    X,
+                    missing_data,
+                    sum_trees,
+                    mean,
+                    linear_fit,
+                    m,
+                    normal,
+                    mu_std,
+                    response,
+                )
+                if index_selected_predictor is not None:
+                    new_indexes = self.tree.idx_leaf_nodes[-2:]
+                    self.expansion_nodes.extend(new_indexes)
+                    self.used_variates.append(index_selected_predictor)
+                    tree_grew = True
+
+        return tree_grew
+
+    def sample_leafs(self, sum_trees, X, mean, linear_fit, m, normal, mu_std, response):
+
+        sample_leaf_values(self.tree, sum_trees, X, mean, linear_fit, m, normal, mu_std, response)
+
 
 class SampleSplittingVariable:
     def __init__(self, alpha_vec):
@@ -396,7 +377,7 @@ def grow_tree(
     available_predictors,
     X,
     missing_data,
-    sum_trees_output,
+    sum_trees,
     mean,
     linear_fit,
     m,
@@ -416,60 +397,78 @@ def grow_tree(
         ~np.isnan(available_splitting_values)
     ]
 
-    if available_splitting_values.size == 0:
-        return False, None
+    if available_splitting_values.size > 0:
+        idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values))
+        split_value = available_splitting_values[idx_selected_splitting_values]
 
-    idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values))
-    split_value = available_splitting_values[idx_selected_splitting_values]
-    new_split_node = SplitNode(
-        index=index_leaf_node,
idx_split_variable=selected_predictor, - split_value=split_value, - ) + new_idx_data_points = get_new_idx_data_points( + split_value, idx_data_points, selected_predictor, X + ) + current_node_children = ( + current_node.get_idx_left_child(), + current_node.get_idx_right_child(), + ) - left_node_idx_data_points, right_node_idx_data_points = get_new_idx_data_points( - split_value, idx_data_points, selected_predictor, X - ) + if response == "mix": + response = "linear" if np.random.random() >= 0.5 else "constant" + + new_nodes = [] + for idx in range(2): + idx_data_point = new_idx_data_points[idx] + node_value, node_linear_params = draw_leaf_value( + sum_trees[idx_data_point], + X[idx_data_point, selected_predictor], + mean, + linear_fit, + m, + normal, + mu_std, + response, + ) - if response == "mix": - response = "linear" if np.random.random() >= 0.5 else "constant" + new_node = LeafNode( + index=current_node_children[idx], + value=node_value, + idx_data_points=idx_data_point, + linear_params=node_linear_params, + ) + new_nodes.append(new_node) - left_node_value, left_node_linear_params = draw_leaf_value( - sum_trees_output[left_node_idx_data_points], - X[left_node_idx_data_points, selected_predictor], - mean, - linear_fit, - m, - normal, - mu_std, - response, - ) - right_node_value, right_node_linear_params = draw_leaf_value( - sum_trees_output[right_node_idx_data_points], - X[right_node_idx_data_points, selected_predictor], - mean, - linear_fit, - m, - normal, - mu_std, - response, - ) - - new_left_node = LeafNode( - index=current_node.get_idx_left_child(), - value=left_node_value, - idx_data_points=left_node_idx_data_points, - linear_params=left_node_linear_params, - ) - new_right_node = LeafNode( - index=current_node.get_idx_right_child(), - value=right_node_value, - idx_data_points=right_node_idx_data_points, - linear_params=right_node_linear_params, - ) - tree.grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node) - - return True, index_selected_predictor + new_split_node = SplitNode( + index=index_leaf_node, + idx_split_variable=selected_predictor, + split_value=split_value, + ) + + # update tree nodes and indexes + tree.delete_node(index_leaf_node) + tree.set_node(index_leaf_node, new_split_node) + tree.set_node(new_nodes[0].index, new_nodes[0]) + tree.set_node(new_nodes[1].index, new_nodes[1]) + + return index_selected_predictor + + +def sample_leaf_values(tree, sum_trees, X, mean, linear_fit, m, normal, mu_std, response): + + for idx in tree.idx_leaf_nodes: + if idx > 0: + leaf = tree[idx] + idx_data_points = leaf.idx_data_points + parent_node = tree[leaf.get_idx_parent_node()] + selected_predictor = parent_node.idx_split_variable + node_value, node_linear_params = draw_leaf_value( + sum_trees[idx_data_points], + X[idx_data_points, selected_predictor], + mean, + linear_fit, + m, + normal, + mu_std, + response, + ) + leaf.value = node_value + leaf.linear_params = node_linear_params def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X): diff --git a/pymc/bart/tree.py b/pymc/bart/tree.py index b982e80bb6..205de5dd8d 100644 --- a/pymc/bart/tree.py +++ b/pymc/bart/tree.py @@ -34,15 +34,8 @@ class Tree: The dictionary's keys are integers that represent the nodes position. The dictionary's values are objects of type SplitNode or LeafNode that represent the nodes of the tree itself. - num_nodes : int - Total number of nodes. idx_leaf_nodes : list List with the index of the leaf nodes of the tree. 
-    idx_prunable_split_nodes : list
-        List with the index of the prunable splitting nodes of the tree. A splitting node is
-        prunable if both its children are leaf nodes.
-    tree_id : int
-        Identifier used to get the previous tree in the ParticleGibbs algorithm used in BART.
     num_observations : int
         Number of observations used to fit BART.
     m : int
 
     Parameters
     ----------
-    tree_id : int, optional
     num_observations : int, optional
     """
 
-    def __init__(self, tree_id=0, num_observations=0, m=0):
+    def __init__(self, num_observations=0, m=0):
         self.tree_structure = {}
-        self.num_nodes = 0
         self.idx_leaf_nodes = []
-        self.idx_prunable_split_nodes = []
-        self.tree_id = tree_id
         self.num_observations = num_observations
         self.m = m
 
@@ -77,7 +66,6 @@ def get_node(self, index):
 
     def set_node(self, index, node):
         self.tree_structure[index] = node
-        self.num_nodes += 1
         if isinstance(node, LeafNode):
             self.idx_leaf_nodes.append(index)
 
@@ -86,7 +74,6 @@ def delete_node(self, index):
         if isinstance(current_node, LeafNode):
             self.idx_leaf_nodes.remove(index)
         del self.tree_structure[index]
-        self.num_nodes -= 1
 
     def predict_output(self):
         output = np.zeros(self.num_observations)
@@ -143,39 +130,12 @@ def _traverse_tree(self, x, node_index=0, split_variable=None):
             current_node, split_variable = self._traverse_tree(x, right_child, split_variable)
         return current_node, split_variable
 
-    def grow_tree(self, index_leaf_node, new_split_node, new_left_node, new_right_node):
-        """
-        Grow the tree from a particular node.
-
-        Parameters
-        ----------
-        index_leaf_node : int
-        new_split_node : SplitNode
-        new_left_node : LeafNode
-        new_right_node : LeafNode
-        """
-        current_node = self.get_node(index_leaf_node)
-
-        self.delete_node(index_leaf_node)
-        self.set_node(index_leaf_node, new_split_node)
-        self.set_node(new_left_node.index, new_left_node)
-        self.set_node(new_right_node.index, new_right_node)
-
-        # The new SplitNode is a prunable node since it has both children.
-        self.idx_prunable_split_nodes.append(index_leaf_node)
-        # If the parent of the node from which the tree is growing was a prunable node,
-        # remove from the list since one of its children is a SplitNode now
-        parent_index = current_node.get_idx_parent_node()
-        if parent_index in self.idx_prunable_split_nodes:
-            self.idx_prunable_split_nodes.remove(parent_index)
-
     @staticmethod
-    def init_tree(tree_id, leaf_node_value, idx_data_points, m):
+    def init_tree(leaf_node_value, idx_data_points, m):
         """
 
         Parameters
         ----------
-        tree_id
         leaf_node_value
         idx_data_points
         m : int
@@ -185,7 +145,7 @@ def init_tree(tree_id, leaf_node_value, idx_data_points, m):
         -------
 
         """
-        new_tree = Tree(tree_id, len(idx_data_points), m)
+        new_tree = Tree(len(idx_data_points), m)
         new_tree[0] = LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points)
         return new_tree
 

From f023e32bf9439eff55c9bf433c70c6e291947c1a Mon Sep 17 00:00:00 2001
From: aloctavodia
Date: Tue, 30 Nov 2021 10:33:19 +0200
Subject: [PATCH 2/3] remove linear and mix response

---
 pymc/bart/bart.py       |  5 ---
 pymc/bart/pgbart.py     | 71 ++++++-----------------------------------
 pymc/bart/tree.py       | 33 +++++++------------
 pymc/tests/test_bart.py |  2 +-
 4 files changed, 22 insertions(+), 89 deletions(-)

diff --git a/pymc/bart/bart.py b/pymc/bart/bart.py
index 7c4d09c5f2..93fd8e2f81 100644
--- a/pymc/bart/bart.py
+++ b/pymc/bart/bart.py
@@ -67,9 +67,6 @@ class BART(NoDistribution):
     k : float
         Scale parameter for the values of the leaf nodes. Defaults to 2. Recommended to be
         between 1 and 3.
- response : str - How the leaf_node values are computed. Available options are ``constant`` (default), - ``linear`` or ``mix``. split_prior : array-like Each element of split_prior should be in the [0, 1] interval and the elements should sum to 1. Otherwise they will be normalized. @@ -84,7 +81,6 @@ def __new__( m=50, alpha=0.25, k=2, - response="constant", split_prior=None, **kwargs, ): @@ -103,7 +99,6 @@ def __new__( m=m, alpha=alpha, k=k, - response=response, split_prior=split_prior, ), )() diff --git a/pymc/bart/pgbart.py b/pymc/bart/pgbart.py index 99b6f5fac7..1486ddc191 100644 --- a/pymc/bart/pgbart.py +++ b/pymc/bart/pgbart.py @@ -68,7 +68,6 @@ def __init__(self, vars=None, num_particles=40, max_stages=100, batch="auto", mo self.m = self.bart.m self.alpha = self.bart.alpha self.k = self.bart.k - self.response = self.bart.response self.alpha_vec = self.bart.split_prior if self.alpha_vec is None: self.alpha_vec = np.ones(self.X.shape[1]) @@ -90,10 +89,8 @@ def __init__(self, vars=None, num_particles=40, max_stages=100, batch="auto", mo self.a_tree = Tree.init_tree( leaf_node_value=self.init_mean / self.m, idx_data_points=np.arange(self.num_observations, dtype="int32"), - m=self.m, ) self.mean = fast_mean() - self.linear_fit = fast_linear_fit() self.normal = NormalSampler() self.prior_prob_leaf_node = compute_prior_probability(self.alpha) @@ -140,11 +137,9 @@ def astep(self, _): self.sum_trees, self.X, self.mean, - self.linear_fit, self.m, self.normal, self.mu_std, - self.response, ) # The old tree and the one with new leafs do not grow so we update the weights only once @@ -162,11 +157,9 @@ def astep(self, _): self.missing_data, self.sum_trees, self.mean, - self.linear_fit, self.m, self.normal, self.mu_std, - self.response, ) if tree_grew: self.update_weight(p) @@ -286,11 +279,9 @@ def sample_tree( missing_data, sum_trees, mean, - linear_fit, m, normal, mu_std, - response, ): tree_grew = False if self.expansion_nodes: @@ -308,11 +299,9 @@ def sample_tree( missing_data, sum_trees, mean, - linear_fit, m, normal, mu_std, - response, ) if index_selected_predictor is not None: new_indexes = self.tree.idx_leaf_nodes[-2:] @@ -322,9 +311,9 @@ def sample_tree( return tree_grew - def sample_leafs(self, sum_trees, X, mean, linear_fit, m, normal, mu_std, response): + def sample_leafs(self, sum_trees, X, mean, m, normal, mu_std): - sample_leaf_values(self.tree, sum_trees, X, mean, linear_fit, m, normal, mu_std, response) + sample_leaf_values(self.tree, sum_trees, X, mean, m, normal, mu_std) class SampleSplittingVariable: @@ -379,11 +368,9 @@ def grow_tree( missing_data, sum_trees, mean, - linear_fit, m, normal, mu_std, - response, ): current_node = tree.get_node(index_leaf_node) idx_data_points = current_node.idx_data_points @@ -409,28 +396,22 @@ def grow_tree( current_node.get_idx_right_child(), ) - if response == "mix": - response = "linear" if np.random.random() >= 0.5 else "constant" - new_nodes = [] for idx in range(2): idx_data_point = new_idx_data_points[idx] - node_value, node_linear_params = draw_leaf_value( + node_value = draw_leaf_value( sum_trees[idx_data_point], X[idx_data_point, selected_predictor], mean, - linear_fit, m, normal, mu_std, - response, ) new_node = LeafNode( index=current_node_children[idx], value=node_value, idx_data_points=idx_data_point, - linear_params=node_linear_params, ) new_nodes.append(new_node) @@ -449,7 +430,7 @@ def grow_tree( return index_selected_predictor -def sample_leaf_values(tree, sum_trees, X, mean, linear_fit, m, normal, mu_std, response): +def 
sample_leaf_values(tree, sum_trees, X, mean, m, normal, mu_std): for idx in tree.idx_leaf_nodes: if idx > 0: @@ -457,18 +438,15 @@ def sample_leaf_values(tree, sum_trees, X, mean, linear_fit, m, normal, mu_std, idx_data_points = leaf.idx_data_points parent_node = tree[leaf.get_idx_parent_node()] selected_predictor = parent_node.idx_split_variable - node_value, node_linear_params = draw_leaf_value( + node_value = draw_leaf_value( sum_trees[idx_data_points], X[idx_data_points, selected_predictor], mean, - linear_fit, m, normal, mu_std, - response, ) leaf.value = node_value - leaf.linear_params = node_linear_params def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X): @@ -480,24 +458,19 @@ def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X) return left_node_idx_data_points, right_node_idx_data_points -def draw_leaf_value(Y_mu_pred, X_mu, mean, linear_fit, m, normal, mu_std, response): +def draw_leaf_value(Y_mu_pred, X_mu, mean, m, normal, mu_std): """Draw Gaussian distributed leaf values""" - linear_params = None if Y_mu_pred.size == 0: - return 0, linear_params + return 0 else: norm = normal.random() * mu_std if Y_mu_pred.size == 1: mu_mean = Y_mu_pred.item() / m - elif response == "constant": + else: mu_mean = mean(Y_mu_pred) / m - elif response == "linear": - Y_fit, linear_params = linear_fit(X_mu, Y_mu_pred) - mu_mean = Y_fit / m - linear_params[2] = norm draw = norm + mu_mean - return draw, linear_params + return draw def fast_mean(): @@ -518,32 +491,6 @@ def mean(a): return mean -def fast_linear_fit(): - """If available use Numba to speed up the computation of the linear fit""" - - def linear_fit(X, Y): - - n = len(Y) - xbar = np.sum(X) / n - ybar = np.sum(Y) / n - - den = X @ X - n * xbar ** 2 - if den > 1e-10: - b = (X @ Y - n * xbar * ybar) / den - else: - b = 0 - a = ybar - b * xbar - Y_fit = a + b * X - return Y_fit, [a, b, 0] - - try: - from numba import jit - - return jit(linear_fit) - except ImportError: - return linear_fit - - def discrete_uniform_sampler(upper_value): """Draw from the uniform distribution with bounds [0, upper_value). diff --git a/pymc/bart/tree.py b/pymc/bart/tree.py index 205de5dd8d..4705690d2d 100644 --- a/pymc/bart/tree.py +++ b/pymc/bart/tree.py @@ -46,11 +46,10 @@ class Tree: num_observations : int, optional """ - def __init__(self, num_observations=0, m=0): + def __init__(self, num_observations=0): self.tree_structure = {} self.idx_leaf_nodes = [] self.num_observations = num_observations - self.m = m def __getitem__(self, index): return self.get_node(index) @@ -97,16 +96,10 @@ def predict_out_of_sample(self, X): float Value of the leaf value where the unobserved point lies. """ - leaf_node, split_variable = self._traverse_tree(X, node_index=0) - linear_params = leaf_node.linear_params - if linear_params is None: - return leaf_node.value - else: - x = X[split_variable].item() - y_x = (linear_params[0] + linear_params[1] * x) / self.m - return y_x + linear_params[2] - - def _traverse_tree(self, x, node_index=0, split_variable=None): + leaf_node = self._traverse_tree(X, node_index=0) + return leaf_node.value + + def _traverse_tree(self, x, node_index=0): """ Traverse the tree starting from a particular node given an unobserved point. 
@@ -121,17 +114,16 @@ def _traverse_tree(self, x, node_index=0, split_variable=None):
         """
         current_node = self.get_node(node_index)
         if isinstance(current_node, SplitNode):
-            split_variable = current_node.idx_split_variable
-            if x[split_variable] <= current_node.split_value:
+            if x[current_node.idx_split_variable] <= current_node.split_value:
                 left_child = current_node.get_idx_left_child()
-                current_node, split_variable = self._traverse_tree(x, left_child, split_variable)
+                current_node = self._traverse_tree(x, left_child)
             else:
                 right_child = current_node.get_idx_right_child()
-                current_node, split_variable = self._traverse_tree(x, right_child, split_variable)
-        return current_node, split_variable
+                current_node = self._traverse_tree(x, right_child)
+        return current_node
 
     @staticmethod
-    def init_tree(leaf_node_value, idx_data_points, m):
+    def init_tree(leaf_node_value, idx_data_points):
         """
 
         Parameters
@@ -145,7 +137,7 @@ def init_tree(leaf_node_value, idx_data_points, m):
         -------
 
         """
-        new_tree = Tree(len(idx_data_points), m)
+        new_tree = Tree(len(idx_data_points))
         new_tree[0] = LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points)
         return new_tree
 
@@ -174,8 +166,7 @@ def __init__(self, index, idx_split_variable, split_value):
 
 
 class LeafNode(BaseNode):
-    def __init__(self, index, value, idx_data_points, linear_params=None):
+    def __init__(self, index, value, idx_data_points):
         super().__init__(index)
         self.value = value
         self.idx_data_points = idx_data_points
-        self.linear_params = linear_params
diff --git a/pymc/tests/test_bart.py b/pymc/tests/test_bart.py
index 3db5543908..0c80cbb23f 100644
--- a/pymc/tests/test_bart.py
+++ b/pymc/tests/test_bart.py
@@ -67,7 +67,7 @@ class TestUtils:
     Y = np.random.normal(0, 1, size=50)
 
     with pm.Model() as model:
-        mu = pm.BART("mu", X, Y, m=10, response="mix")
+        mu = pm.BART("mu", X, Y, m=10)
         sigma = pm.HalfNormal("sigma", 1)
         y = pm.Normal("y", mu, sigma, observed=Y)
         idata = pm.sample(random_seed=3415)

From 45d5afb87e92af197eb3083c72facbb42317ba8c Mon Sep 17 00:00:00 2001
From: aloctavodia
Date: Tue, 30 Nov 2021 11:40:00 +0200
Subject: [PATCH 3/3] update release notes

---
 RELEASE-NOTES.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md
index f3035fe825..660ab46820 100644
--- a/RELEASE-NOTES.md
+++ b/RELEASE-NOTES.md
@@ -93,9 +93,9 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
 - `pm.DensityDist` can now accept an optional `logcdf` keyword argument to pass in a function to compute the cumulative density function of the distribution (see [5026](https://github.com/pymc-devs/pymc/pull/5026)).
 - `pm.DensityDist` can now accept an optional `get_moment` keyword argument to pass in a function to compute the moment of the distribution (see [5026](https://github.com/pymc-devs/pymc/pull/5026)).
 - New features for BART:
-  - Added linear response, increased number of trees fitted per step [5044](https://github.com/pymc-devs/pymc3/pull/5044).
   - Added partial dependence plots and individual conditional expectation plots [5091](https://github.com/pymc-devs/pymc3/pull/5091).
   - Modify how particle weights are computed. This improves accuracy of the modeled function (see [5177](https://github.com/pymc-devs/pymc3/pull/5177)).
+  - Improve sampling, increase default number of particles [5229](https://github.com/pymc-devs/pymc3/pull/5229).
- `pm.Data` now passes additional kwargs to `aesara.shared`. [#5098](https://github.com/pymc-devs/pymc/pull/5098)
- ...
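The weight bookkeeping reworked in patch 1 can be read off `update_weight` and the `normalize` docstring: weights are updated additively in log space, and the resampling probabilities come from a logsumexp/softmax computation. The body of `normalize` is not part of this diff, so the stand-alone sketch below only illustrates the computation its docstring names; the function name and signature here are assumptions, not the sampler's actual code:

```python
import numpy as np

def normalize(log_weights, log_num_particles):
    """Sketch: logsumexp for W_t, softmax for the resampling probabilities."""
    log_w = np.asarray(log_weights, dtype=float)
    log_w_max = log_w.max()            # subtract the max to stabilize the exponentials
    w = np.exp(log_w - log_w_max)
    w_sum = w.sum()
    W_t = log_w_max + np.log(w_sum) - log_num_particles
    return W_t, w / w_sum              # normalized weights sum to 1

# update_weight, as shown in the diff, reduces to:
#   old=True : log_weight  = new_likelihood  (and cache it as old_likelihood_logp)
#   old=False: log_weight += new_likelihood - old_likelihood_logp
```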
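`ParticleTree.sample_tree` decides stochastically whether a candidate leaf grows: it looks up the prior probability that a node at the leaf's depth remains a leaf and attempts `grow_tree` only when a uniform draw exceeds it. A minimal sketch of just that decision; the helper name is illustrative, and the `prior_prob_leaf_node` table is built by `compute_prior_probability`, whose body is outside this diff:

```python
import numpy as np

def should_grow(depth, prior_prob_leaf_node):
    # prior_prob_leaf_node[d]: prior probability that a depth-d node stays a leaf
    prob_leaf = prior_prob_leaf_node[depth]
    return prob_leaf < np.random.random()  # True -> attempt grow_tree on this leaf
```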
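With the three patches applied, BART models are specified without the `response` keyword and sampled with the new defaults (`num_particles=40`, batch size 10% of `m` both during and after tuning). A minimal usage sketch mirroring the updated `test_bart.py` model above; the synthetic data and shapes are illustrative only:

```python
import numpy as np
import pymc as pm

X = np.random.uniform(0, 10, size=(50, 2))
Y = np.random.normal(0, 1, size=50)

with pm.Model() as model:
    mu = pm.BART("mu", X, Y, m=10)  # the `response` kwarg is gone; leaf values are constant
    sigma = pm.HalfNormal("sigma", 1)
    y = pm.Normal("y", mu, sigma, observed=Y)
    # PGBART is assigned to `mu` automatically via its `competence` method
    idata = pm.sample(random_seed=3415)
```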