Allow training separate tree structures if training multiple trees #98

Merged (1 commit, Jul 7, 2023)
13 changes: 12 additions & 1 deletion pymc_bart/bart.py
@@ -40,7 +40,7 @@ class BARTRV(RandomVariable):
ndims_params: List[int] = [2, 1, 0, 0, 0, 1]
dtype: str = "floatX"
_print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}")
all_trees = List[List[Tree]]
all_trees = List[List[List[Tree]]]

def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
return dist_params[0].shape[:1]
@@ -96,6 +96,15 @@ class BART(Distribution):
List of SplitRule objects, one per column in input data.
Allows using different split rules for different columns. Default is ContinuousSplitRule.
Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
shape : Optional[Tuple], default None
Specify the output shape. If shape is different from (len(X),) (the default), train a
separate tree for each value in the other dimensions.
separate_trees : Optional[bool], default False
When training multiple trees (by setting a shape parameter), the default behavior is to
learn a joint tree structure and only have different leaf values for each dimension.
This flag forces a fully separate tree structure to be trained instead.
This is unnecessary in many cases and is considerably slower, multiplying
run-time roughly by the number of dimensions.

Notes
-----
@@ -115,6 +124,7 @@ def __new__(
response: str = "constant",
split_prior: Optional[List[float]] = None,
split_rules: Optional[SplitRule] = None,
separate_trees: Optional[bool] = False,
**kwargs,
):
manager = Manager()
@@ -141,6 +151,7 @@ def __new__(
beta=beta,
split_prior=split_prior,
split_rules=split_rules,
separate_trees=separate_trees,
),
)()

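The new `shape` and `separate_trees` arguments documented above compose as in the test added at the bottom of this PR. A minimal usage sketch (toy data and illustrative variable names; not part of the diff):

```python
import numpy as np
import pymc as pm
import pymc_bart as pmb

# Toy multi-class data: the first column of X is perfectly informative.
Y = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
X = np.concatenate([Y[:, None], np.random.randint(0, 6, size=(9, 4))], axis=1)

with pm.Model():
    # shape=(3, 9): one output row per class, one column per observation.
    # separate_trees=True grows an independent tree structure per class instead of
    # sharing the structure and varying only the leaf values.
    logodds = pmb.BART("logodds", X, Y, m=2, shape=(3, 9), separate_trees=True)
    y = pm.Categorical("y", p=pm.math.softmax(logodds.T, axis=-1), observed=Y)
    idata = pm.sample(tune=300, draws=300)
```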
183 changes: 102 additions & 81 deletions pymc_bart/pgbart.py
@@ -134,9 +134,17 @@ def __init__(
self.missing_data = np.any(np.isnan(self.X))
self.m = self.bart.m
self.response = self.bart.response

shape = initial_values[value_bart.name].shape

self.shape = 1 if len(shape) == 1 else shape[0]

# Set trees_shape (dim for separate tree structures)
# and leaves_shape (dim for leaf node values)
# One of the two is always one, the other equal to self.shape
self.trees_shape = self.shape if self.bart.separate_trees else 1
self.leaves_shape = self.shape if not self.bart.separate_trees else 1

if self.bart.split_prior:
self.alpha_vec = self.bart.split_prior
else:
@@ -153,27 +161,31 @@ def __init__(
self.available_predictors = list(range(self.num_variates))

# if data is binary
self.leaf_sd = np.ones((self.trees_shape, self.leaves_shape))

y_unique = np.unique(self.bart.Y)
if y_unique.size == 2 and np.all(y_unique == [0, 1]):
self.leaf_sd = 3 / self.m**0.5
self.leaf_sd *= 3 / self.m**0.5
else:
self.leaf_sd = self.bart.Y.std() / self.m**0.5
self.leaf_sd *= self.bart.Y.std() / self.m**0.5

self.running_sd = RunningSd(shape)
self.running_sd = [
RunningSd((self.leaves_shape, self.num_observations)) for _ in range(self.trees_shape)
]

self.sum_trees = np.full((self.shape, self.bart.Y.shape[0]), init_mean).astype(
config.floatX
)
self.sum_trees = np.full(
(self.trees_shape, self.leaves_shape, self.bart.Y.shape[0]), init_mean
).astype(config.floatX)
self.sum_trees_noi = self.sum_trees - init_mean
self.a_tree = Tree.new_tree(
leaf_node_value=init_mean / self.m,
idx_data_points=np.arange(self.num_observations, dtype="int32"),
num_observations=self.num_observations,
shape=self.shape,
shape=self.leaves_shape,
split_rules=self.split_rules,
)

self.normal = NormalSampler(1, self.shape)
self.normal = NormalSampler(1, self.leaves_shape)
self.uniform = UniformSampler(0, 1)
self.prior_prob_leaf_node = compute_prior_probability(self.bart.alpha, self.bart.beta)
self.ssv = SampleSplittingVariable(self.alpha_vec)
@@ -188,8 +200,10 @@ def __init__(
self.indices = list(range(1, num_particles))
shared = make_shared_replacements(initial_values, vars, model)
self.likelihood_logp = logp(initial_values, [model.datalogp], vars, shared)
self.all_particles = [ParticleTree(self.a_tree) for _ in range(self.m)]
self.all_trees = np.array([p.tree for p in self.all_particles])
self.all_particles = [
[ParticleTree(self.a_tree) for _ in range(self.m)] for _ in range(self.trees_shape)
]
self.all_trees = np.array([[p.tree for p in pl] for pl in self.all_particles])
self.lower = 0
self.iter = 0
super().__init__(vars, shared)
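
A quick sanity check of the `trees_shape` / `leaves_shape` bookkeeping introduced above (a standalone sketch mirroring the `__init__` logic, not library code): exactly one of the two equals `self.shape`, the other is always 1.

```python
def split_shape(shape, separate_trees):
    """Mirror the trees_shape / leaves_shape bookkeeping in PGBART.__init__."""
    trees_shape = shape if separate_trees else 1       # dim for separate tree structures
    leaves_shape = shape if not separate_trees else 1  # dim for leaf node values
    return trees_shape, leaves_shape


# Shared structure: one set of trees, 3 leaf values per node.
assert split_shape(3, separate_trees=False) == (1, 3)
# Fully separate structures: 3 independent sets of trees, scalar leaves.
assert split_shape(3, separate_trees=True) == (3, 1)
```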
@@ -201,72 +215,75 @@ def astep(self, _):
tree_ids = range(self.lower, upper)
self.lower = upper if upper < self.m else 0

for tree_id in tree_ids:
self.iter += 1
# Compute the sum of trees without the old tree that we are attempting to replace
self.sum_trees_noi = self.sum_trees - self.all_particles[tree_id].tree._predict()
# Generate an initial set of particles
# at the end we return one of these particles as the new tree
particles = self.init_particles(tree_id)

while True:
# Sample each particle (try to grow each tree), except for the first one
stop_growing = True
for p in particles[1:]:
if p.sample_tree(
self.ssv,
self.available_predictors,
self.prior_prob_leaf_node,
self.X,
self.missing_data,
self.sum_trees,
self.leaf_sd,
self.m,
self.response,
self.normal,
self.shape,
):
self.update_weight(p)
if p.expansion_nodes:
stop_growing = False
if stop_growing:
break

# Normalize weights
normalized_weights = self.normalize(particles[1:])

# Resample
particles = self.resample(particles, normalized_weights)

normalized_weights = self.normalize(particles)
# Get the new particle and associated tree
self.all_particles[tree_id], new_tree = self.get_particle_tree(
particles, normalized_weights
)
# Update the sum of trees
new = new_tree._predict()
self.sum_trees = self.sum_trees_noi + new
# To reduce memory usage, we trim the tree
self.all_trees[tree_id] = new_tree.trim()

if self.tune:
# Update the splitting variable and the splitting variable sampler
if self.iter > self.m:
self.ssv = SampleSplittingVariable(self.alpha_vec)

for index in new_tree.get_split_variables():
self.alpha_vec[index] += 1

# update standard deviation at leaf nodes
if self.iter > 2:
self.leaf_sd = self.running_sd.update(new)
else:
self.running_sd.update(new)
for odim in range(self.trees_shape):
for tree_id in tree_ids:
self.iter += 1
# Compute the sum of trees without the old tree that we are attempting to replace
self.sum_trees_noi[odim] = (
self.sum_trees[odim] - self.all_particles[odim][tree_id].tree._predict()
)
# Generate an initial set of particles
# at the end we return one of these particles as the new tree
particles = self.init_particles(tree_id, odim)

while True:
# Sample each particle (try to grow each tree), except for the first one
stop_growing = True
for p in particles[1:]:
if p.sample_tree(
self.ssv,
self.available_predictors,
self.prior_prob_leaf_node,
self.X,
self.missing_data,
self.sum_trees[odim],
self.leaf_sd[odim],
self.m,
self.response,
self.normal,
self.leaves_shape,
):
self.update_weight(p, odim)
if p.expansion_nodes:
stop_growing = False
if stop_growing:
break

# Normalize weights
normalized_weights = self.normalize(particles[1:])

# Resample
particles = self.resample(particles, normalized_weights)

normalized_weights = self.normalize(particles)
# Get the new particle and associated tree
self.all_particles[odim][tree_id], new_tree = self.get_particle_tree(
particles, normalized_weights
)
# Update the sum of trees
new = new_tree._predict()
self.sum_trees[odim] = self.sum_trees_noi[odim] + new
# To reduce memory usage, we trim the tree
self.all_trees[odim][tree_id] = new_tree.trim()

if self.tune:
# Update the splitting variable and the splitting variable sampler
if self.iter > self.m:
self.ssv = SampleSplittingVariable(self.alpha_vec)

for index in new_tree.get_split_variables():
self.alpha_vec[index] += 1

# update standard deviation at leaf nodes
if self.iter > 2:
self.leaf_sd[odim] = self.running_sd[odim].update(new)
else:
self.running_sd[odim].update(new)

else:
# update the variable inclusion
for index in new_tree.get_split_variables():
variable_inclusion[index] += 1
else:
# update the variable inclusion
for index in new_tree.get_split_variables():
variable_inclusion[index] += 1

if not self.tune:
self.bart.all_trees.append(self.all_trees)
@@ -331,23 +348,27 @@ def systematic(self, normalized_weights: npt.NDArray[np.float_]) -> npt.NDArray[
single_uniform = (self.uniform.rvs() + np.arange(lnw)) / lnw
return inverse_cdf(single_uniform, normalized_weights)

def init_particles(self, tree_id: int) -> List[ParticleTree]:
def init_particles(self, tree_id: int, odim: int) -> List[ParticleTree]:
"""Initialize particles."""
p0: ParticleTree = self.all_particles[tree_id]
p0: ParticleTree = self.all_particles[odim][tree_id]
# The old tree does not grow so we update the weight only once
self.update_weight(p0)
self.update_weight(p0, odim)
particles: List[ParticleTree] = [p0]

particles.extend(ParticleTree(self.a_tree) for _ in self.indices)
return particles

def update_weight(self, particle: ParticleTree) -> None:
def update_weight(self, particle: ParticleTree, odim: int) -> None:
"""
Update the weight of a particle.
"""
new_likelihood = self.likelihood_logp(
(self.sum_trees_noi + particle.tree._predict()).flatten()

delta = (
np.identity(self.trees_shape)[odim][:, None, None]
* particle.tree._predict()[None, :, :]
)

new_likelihood = self.likelihood_logp((self.sum_trees_noi + delta).flatten())
particle.log_weight = new_likelihood

@staticmethod
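The one-hot broadcast in the new `update_weight` is what keeps the per-dimension updates independent: only the `odim` slice of `sum_trees_noi` receives the candidate tree's prediction. A standalone NumPy sketch of that broadcast (shapes chosen arbitrarily for illustration):

```python
import numpy as np

trees_shape, leaves_shape, n_obs = 3, 1, 5
sum_trees_noi = np.zeros((trees_shape, leaves_shape, n_obs))
prediction = np.full((leaves_shape, n_obs), 2.0)  # stand-in for particle.tree._predict()
odim = 1

# One-hot vector along the tree-structure axis, broadcast against the prediction:
# the result is zero everywhere except the odim slice.
delta = np.identity(trees_shape)[odim][:, None, None] * prediction[None, :, :]

assert delta.shape == (trees_shape, leaves_shape, n_obs)
assert np.all(delta[odim] == prediction)
assert np.all(np.delete(delta, odim, axis=0) == 0.0)
```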
4 changes: 2 additions & 2 deletions pymc_bart/tree.py
@@ -147,7 +147,7 @@ def new_tree(
)
},
idx_leaf_nodes=[0],
output=np.zeros((num_observations, shape)).astype(config.floatX).squeeze(),
output=np.zeros((num_observations, shape)).astype(config.floatX),
split_rules=split_rules,
)

@@ -226,7 +226,7 @@ def _predict(self) -> npt.NDArray[np.float_]:
if self.idx_leaf_nodes is not None:
for node_index in self.idx_leaf_nodes:
leaf_node = self.get_node(node_index)
output[leaf_node.idx_data_points] = leaf_node.value.squeeze()
output[leaf_node.idx_data_points] = leaf_node.value
return output.T

def predict(
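The dropped `.squeeze()` calls keep the leaf axis even when it has length 1, so `_predict()` always returns a `(leaves_shape, num_observations)` array that the broadcasting in pgbart.py can rely on. A small illustration of the difference (assumed shapes, not library code):

```python
import numpy as np

num_observations, leaves_shape = 4, 1
output = np.zeros((num_observations, leaves_shape))

# With the old .squeeze(), the leaf axis disappears when leaves_shape == 1,
# so the transposed prediction is 1-D instead of (leaves_shape, num_observations).
assert output.squeeze().T.shape == (num_observations,)
assert output.T.shape == (leaves_shape, num_observations)
```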
18 changes: 12 additions & 6 deletions pymc_bart/utils.py
@@ -44,11 +44,12 @@ def _sample_posterior(
Indexes of the variables to exclude when computing predictions
"""
stacked_trees = all_trees

if isinstance(X, Variable):
X = X.eval()

if size is None:
size_iter: Union[List, Tuple] = ()
size_iter: Union[List, Tuple] = (1,)
elif isinstance(size, int):
size_iter = [size]
else:
@@ -60,13 +61,18 @@

idx = rng.integers(0, len(stacked_trees), size=flatten_size)

pred = np.zeros((flatten_size, X.shape[0], shape))
trees_shape = len(stacked_trees[0])
leaves_shape = shape // trees_shape

pred = np.zeros((flatten_size, trees_shape, leaves_shape, X.shape[0]))

for ind, p in enumerate(pred):
for tree in stacked_trees[idx[ind]]:
p += tree.predict(x=X, excluded=excluded, shape=shape).T
pred.reshape((*size_iter, shape, -1))
return pred
for odim, odim_trees in enumerate(stacked_trees[idx[ind]]):
for tree in odim_trees:
p[odim] += tree.predict(x=X, excluded=excluded, shape=leaves_shape)

# pred.reshape((*size_iter, shape, -1))
return pred.transpose((0, 3, 1, 2)).reshape((*size_iter, -1, shape))


def plot_convergence(
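The reshuffling at the end of `_sample_posterior` folds the tree-structure and leaf axes back into the user-facing `shape`. A standalone sketch of the shape bookkeeping (arbitrary example sizes; `shape == trees_shape * leaves_shape` as in the code above):

```python
import numpy as np

size, n_obs, trees_shape, leaves_shape = 2, 5, 3, 1
shape = trees_shape * leaves_shape  # total output dimension exposed to the user
size_iter = (size,)
flatten_size = size                 # product of size_iter

# Per-draw predictions, one block per separately-trained tree structure.
pred = np.zeros((flatten_size, trees_shape, leaves_shape, n_obs))

# (draws, trees, leaves, obs) -> (draws, obs, trees, leaves) -> (*size_iter, obs, shape)
out = pred.transpose((0, 3, 1, 2)).reshape((*size_iter, -1, shape))
assert out.shape == (size, n_obs, shape)
```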
27 changes: 27 additions & 0 deletions tests/test_bart.py
@@ -217,3 +217,30 @@ def test_bart_moment(size, expected):
with pm.Model() as model:
pmb.BART("x", X=X, Y=Y, size=size)
assert_moment_is_expected(model, expected)


@pytest.mark.parametrize(
argnames="separate_trees,split_rule",
argvalues=[
(False, pmb.ContinuousSplitRule),
(False, pmb.OneHotSplitRule),
(False, pmb.SubsetSplitRule),
(True, pmb.ContinuousSplitRule),
],
ids=["continuous", "one-hot", "subset", "separate-trees"],
)
def test_categorical_model(separate_trees, split_rule):

Y = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
X = np.concatenate([Y[:, None], np.random.randint(0, 6, size=(9, 4))], axis=1)

with pm.Model() as model:
lo = pmb.BART("logodds", X, Y, m=2, shape=(3, 9),
split_rules=[split_rule]*5,
separate_trees=separate_trees)
y = pm.Categorical("y", p=pm.math.softmax(lo.T, axis=-1), observed=Y)
idata = pm.sample(random_seed=3415, tune=300, draws=300)
idata = pm.sample_posterior_predictive(idata, predictions=True, extend_inferencedata=True)

# Fit should be good enough so right category is selected over 50% of time
assert (idata.predictions.y.median(["chain", "draw"]) == Y).all()