
Commit fc665d8

Added a separate_trees flag that allows training fully separate tree structures instead of just separate leaf node values.
1 parent 6d8987d commit fc665d8

File tree: 5 files changed (+155, -90 lines)


pymc_bart/bart.py

Lines changed: 12 additions & 1 deletion
@@ -40,7 +40,7 @@ class BARTRV(RandomVariable):
     ndims_params: List[int] = [2, 1, 0, 0, 0, 1]
     dtype: str = "floatX"
     _print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}")
-    all_trees = List[List[Tree]]
+    all_trees = List[List[List[Tree]]]

     def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
         return dist_params[0].shape[:1]
@@ -96,6 +96,15 @@ class BART(Distribution):
         List of SplitRule objects, one per column in input data.
         Allows using different split rules for different columns. Default is ContinuousSplitRule.
         Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
+    shape : Optional[Tuple], default None
+        Specify the output shape. If shape is different from (len(X)) (the default), train a
+        separate tree for each value in other dimensions.
+    separate_trees : Optional[bool], default False
+        When training multiple trees (by setting a shape parameter), the default behavior is to
+        learn a joint tree structure and only have different leaf values for each.
+        This flag forces a fully separate tree structure to be trained instead.
+        This is unnecessary in many cases and is considerably slower, multiplying
+        run-time roughly by the number of dimensions.

     Notes
     -----
@@ -115,6 +124,7 @@ def __new__(
         response: str = "constant",
         split_prior: Optional[List[float]] = None,
         split_rules: Optional[SplitRule] = None,
+        separate_trees: Optional[bool] = False,
         **kwargs,
     ):
         manager = Manager()
@@ -141,6 +151,7 @@ def __new__(
                 beta=beta,
                 split_prior=split_prior,
                 split_rules=split_rules,
+                separate_trees=separate_trees,
             ),
         )()
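As a quick illustration of the new keyword (a minimal sketch with synthetic data, modeled on the test added in this commit, not itself part of the diff):

import numpy as np
import pymc as pm
import pymc_bart as pmb

rng = np.random.default_rng(0)
Y = rng.normal(size=9)
X = rng.normal(size=(9, 4))

with pm.Model():
    # shape=(3, 9) requests three outputs over the 9 observations.
    # By default the three outputs share tree structures and differ only
    # in their leaf values; separate_trees=True grows an independent
    # structure per output, at roughly 3x the run-time.
    mu = pmb.BART("mu", X=X, Y=Y, m=10, shape=(3, 9), separate_trees=True)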

pymc_bart/pgbart.py

Lines changed: 102 additions & 81 deletions
@@ -134,9 +134,17 @@ def __init__(
         self.missing_data = np.any(np.isnan(self.X))
         self.m = self.bart.m
         self.response = self.bart.response
+
         shape = initial_values[value_bart.name].shape
+
         self.shape = 1 if len(shape) == 1 else shape[0]

+        # Set trees_shape (dim for separate tree structures)
+        # and leaves_shape (dim for leaf node values)
+        # One of the two is always one, the other equal to self.shape
+        self.trees_shape = self.shape if self.bart.separate_trees else 1
+        self.leaves_shape = self.shape if not self.bart.separate_trees else 1
+
         if self.bart.split_prior:
             self.alpha_vec = self.bart.split_prior
         else:
@@ -153,27 +161,31 @@ def __init__(
         self.available_predictors = list(range(self.num_variates))

         # if data is binary
+        self.leaf_sd = np.ones((self.trees_shape, self.leaves_shape))
+
         y_unique = np.unique(self.bart.Y)
         if y_unique.size == 2 and np.all(y_unique == [0, 1]):
-            self.leaf_sd = 3 / self.m**0.5
+            self.leaf_sd *= 3 / self.m**0.5
         else:
-            self.leaf_sd = self.bart.Y.std() / self.m**0.5
+            self.leaf_sd *= self.bart.Y.std() / self.m**0.5

-        self.running_sd = RunningSd(shape)
+        self.running_sd = [
+            RunningSd((self.leaves_shape, self.num_observations)) for _ in range(self.trees_shape)
+        ]

-        self.sum_trees = np.full((self.shape, self.bart.Y.shape[0]), init_mean).astype(
-            config.floatX
-        )
+        self.sum_trees = np.full(
+            (self.trees_shape, self.leaves_shape, self.bart.Y.shape[0]), init_mean
+        ).astype(config.floatX)
         self.sum_trees_noi = self.sum_trees - init_mean
         self.a_tree = Tree.new_tree(
             leaf_node_value=init_mean / self.m,
             idx_data_points=np.arange(self.num_observations, dtype="int32"),
             num_observations=self.num_observations,
-            shape=self.shape,
+            shape=self.leaves_shape,
             split_rules=self.split_rules,
         )

-        self.normal = NormalSampler(1, self.shape)
+        self.normal = NormalSampler(1, self.leaves_shape)
         self.uniform = UniformSampler(0, 1)
         self.prior_prob_leaf_node = compute_prior_probability(self.bart.alpha, self.bart.beta)
         self.ssv = SampleSplittingVariable(self.alpha_vec)
@@ -188,8 +200,10 @@ def __init__(
         self.indices = list(range(1, num_particles))
         shared = make_shared_replacements(initial_values, vars, model)
         self.likelihood_logp = logp(initial_values, [model.datalogp], vars, shared)
-        self.all_particles = [ParticleTree(self.a_tree) for _ in range(self.m)]
-        self.all_trees = np.array([p.tree for p in self.all_particles])
+        self.all_particles = [
+            [ParticleTree(self.a_tree) for _ in range(self.m)] for _ in range(self.trees_shape)
+        ]
+        self.all_trees = np.array([[p.tree for p in pl] for pl in self.all_particles])
         self.lower = 0
         self.iter = 0
         super().__init__(vars, shared)
@@ -201,72 +215,75 @@ def astep(self, _):
         tree_ids = range(self.lower, upper)
         self.lower = upper if upper < self.m else 0

-        for tree_id in tree_ids:
-            self.iter += 1
-            # Compute the sum of trees without the old tree that we are attempting to replace
-            self.sum_trees_noi = self.sum_trees - self.all_particles[tree_id].tree._predict()
-            # Generate an initial set of particles
-            # at the end we return one of these particles as the new tree
-            particles = self.init_particles(tree_id)
-
-            while True:
-                # Sample each particle (try to grow each tree), except for the first one
-                stop_growing = True
-                for p in particles[1:]:
-                    if p.sample_tree(
-                        self.ssv,
-                        self.available_predictors,
-                        self.prior_prob_leaf_node,
-                        self.X,
-                        self.missing_data,
-                        self.sum_trees,
-                        self.leaf_sd,
-                        self.m,
-                        self.response,
-                        self.normal,
-                        self.shape,
-                    ):
-                        self.update_weight(p)
-                    if p.expansion_nodes:
-                        stop_growing = False
-                if stop_growing:
-                    break
-
-            # Normalize weights
-            normalized_weights = self.normalize(particles[1:])
-
-            # Resample
-            particles = self.resample(particles, normalized_weights)
-
-            normalized_weights = self.normalize(particles)
-            # Get the new particle and associated tree
-            self.all_particles[tree_id], new_tree = self.get_particle_tree(
-                particles, normalized_weights
-            )
-            # Update the sum of trees
-            new = new_tree._predict()
-            self.sum_trees = self.sum_trees_noi + new
-            # To reduce memory usage, we trim the tree
-            self.all_trees[tree_id] = new_tree.trim()
-
-            if self.tune:
-                # Update the splitting variable and the splitting variable sampler
-                if self.iter > self.m:
-                    self.ssv = SampleSplittingVariable(self.alpha_vec)
-
-                for index in new_tree.get_split_variables():
-                    self.alpha_vec[index] += 1
-
-                # update standard deviation at leaf nodes
-                if self.iter > 2:
-                    self.leaf_sd = self.running_sd.update(new)
-                else:
-                    self.running_sd.update(new)
+        for odim in range(self.trees_shape):
+            for tree_id in tree_ids:
+                self.iter += 1
+                # Compute the sum of trees without the old tree that we are attempting to replace
+                self.sum_trees_noi[odim] = (
+                    self.sum_trees[odim] - self.all_particles[odim][tree_id].tree._predict()
+                )
+                # Generate an initial set of particles
+                # at the end we return one of these particles as the new tree
+                particles = self.init_particles(tree_id, odim)
+
+                while True:
+                    # Sample each particle (try to grow each tree), except for the first one
+                    stop_growing = True
+                    for p in particles[1:]:
+                        if p.sample_tree(
+                            self.ssv,
+                            self.available_predictors,
+                            self.prior_prob_leaf_node,
+                            self.X,
+                            self.missing_data,
+                            self.sum_trees[odim],
+                            self.leaf_sd[odim],
+                            self.m,
+                            self.response,
+                            self.normal,
+                            self.leaves_shape,
+                        ):
+                            self.update_weight(p, odim)
+                        if p.expansion_nodes:
+                            stop_growing = False
+                    if stop_growing:
+                        break
+
+                # Normalize weights
+                normalized_weights = self.normalize(particles[1:])
+
+                # Resample
+                particles = self.resample(particles, normalized_weights)
+
+                normalized_weights = self.normalize(particles)
+                # Get the new particle and associated tree
+                self.all_particles[odim][tree_id], new_tree = self.get_particle_tree(
+                    particles, normalized_weights
+                )
+                # Update the sum of trees
+                new = new_tree._predict()
+                self.sum_trees[odim] = self.sum_trees_noi[odim] + new
+                # To reduce memory usage, we trim the tree
+                self.all_trees[odim][tree_id] = new_tree.trim()
+
+                if self.tune:
+                    # Update the splitting variable and the splitting variable sampler
+                    if self.iter > self.m:
+                        self.ssv = SampleSplittingVariable(self.alpha_vec)
+
+                    for index in new_tree.get_split_variables():
+                        self.alpha_vec[index] += 1
+
+                    # update standard deviation at leaf nodes
+                    if self.iter > 2:
+                        self.leaf_sd[odim] = self.running_sd[odim].update(new)
+                    else:
+                        self.running_sd[odim].update(new)

-            else:
-                # update the variable inclusion
-                for index in new_tree.get_split_variables():
-                    variable_inclusion[index] += 1
+                else:
+                    # update the variable inclusion
+                    for index in new_tree.get_split_variables():
+                        variable_inclusion[index] += 1

         if not self.tune:
             self.bart.all_trees.append(self.all_trees)
@@ -331,23 +348,27 @@ def systematic(self, normalized_weights: npt.NDArray[np.float_]) -> npt.NDArray[
         single_uniform = (self.uniform.rvs() + np.arange(lnw)) / lnw
         return inverse_cdf(single_uniform, normalized_weights)

-    def init_particles(self, tree_id: int) -> List[ParticleTree]:
+    def init_particles(self, tree_id: int, odim: int) -> List[ParticleTree]:
         """Initialize particles."""
-        p0: ParticleTree = self.all_particles[tree_id]
+        p0: ParticleTree = self.all_particles[odim][tree_id]
         # The old tree does not grow so we update the weight only once
-        self.update_weight(p0)
+        self.update_weight(p0, odim)
         particles: List[ParticleTree] = [p0]

         particles.extend(ParticleTree(self.a_tree) for _ in self.indices)
         return particles

-    def update_weight(self, particle: ParticleTree) -> None:
+    def update_weight(self, particle: ParticleTree, odim: int) -> None:
         """
         Update the weight of a particle.
         """
-        new_likelihood = self.likelihood_logp(
-            (self.sum_trees_noi + particle.tree._predict()).flatten()
+
+        delta = (
+            np.identity(self.trees_shape)[odim][:, None, None]
+            * particle.tree._predict()[None, :, :]
         )
+
+        new_likelihood = self.likelihood_logp((self.sum_trees_noi + delta).flatten())
         particle.log_weight = new_likelihood

     @staticmethod
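The delta construction in update_weight deserves a note: the likelihood is evaluated on the flattened (trees_shape, leaves_shape, num_observations) sum of trees, so a candidate tree for output dimension odim must be injected into that slot only, leaving the other slots unchanged. Multiplying a one-hot row of the identity matrix against the broadcast prediction does exactly that. A standalone numpy sketch, with shapes picked arbitrarily for illustration:

import numpy as np

trees_shape, leaves_shape, n_obs = 3, 1, 4
pred = np.full((leaves_shape, n_obs), 0.5)  # stand-in for particle.tree._predict()
odim = 1

# np.identity(3)[1] == [0., 1., 0.]; broadcasting places `pred` in slot 1 only
delta = np.identity(trees_shape)[odim][:, None, None] * pred[None, :, :]

assert delta.shape == (trees_shape, leaves_shape, n_obs)
assert np.array_equal(delta[odim], pred)
assert not delta[[i for i in range(trees_shape) if i != odim]].any()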

pymc_bart/tree.py

Lines changed: 2 additions & 2 deletions
@@ -147,7 +147,7 @@ def new_tree(
                 )
             },
             idx_leaf_nodes=[0],
-            output=np.zeros((num_observations, shape)).astype(config.floatX).squeeze(),
+            output=np.zeros((num_observations, shape)).astype(config.floatX),
             split_rules=split_rules,
         )

@@ -226,7 +226,7 @@ def _predict(self) -> npt.NDArray[np.float_]:
         if self.idx_leaf_nodes is not None:
             for node_index in self.idx_leaf_nodes:
                 leaf_node = self.get_node(node_index)
-                output[leaf_node.idx_data_points] = leaf_node.value.squeeze()
+                output[leaf_node.idx_data_points] = leaf_node.value
         return output.T

     def predict(
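Dropping the two .squeeze() calls keeps the tree output explicitly two-dimensional even when there is a single leaf value per node, which is what the (trees_shape, leaves_shape, num_observations) bookkeeping in pgbart.py now relies on. A small numpy sketch of the shape behavior (values are placeholders):

import numpy as np

num_observations, shape = 5, 1
output = np.zeros((num_observations, shape))  # without .squeeze() this stays 2-D
leaf_value = np.array([0.7])                  # one value per leaf dimension
idx_data_points = np.array([0, 2])

# the (2, 1) slice broadcasts against the (1,) leaf value without squeezing
output[idx_data_points] = leaf_value
assert output.T.shape == (shape, num_observations)  # what _predict() returns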

pymc_bart/utils.py

Lines changed: 12 additions & 6 deletions
@@ -44,11 +44,12 @@ def _sample_posterior(
         Indexes of the variables to exclude when computing predictions
     """
     stacked_trees = all_trees
+
     if isinstance(X, Variable):
         X = X.eval()

     if size is None:
-        size_iter: Union[List, Tuple] = ()
+        size_iter: Union[List, Tuple] = (1,)
     elif isinstance(size, int):
         size_iter = [size]
     else:
@@ -60,13 +61,18 @@ def _sample_posterior(

     idx = rng.integers(0, len(stacked_trees), size=flatten_size)

-    pred = np.zeros((flatten_size, X.shape[0], shape))
+    trees_shape = len(stacked_trees[0])
+    leaves_shape = shape // trees_shape
+
+    pred = np.zeros((flatten_size, trees_shape, leaves_shape, X.shape[0]))

     for ind, p in enumerate(pred):
-        for tree in stacked_trees[idx[ind]]:
-            p += tree.predict(x=X, excluded=excluded, shape=shape).T
-    pred.reshape((*size_iter, shape, -1))
-    return pred
+        for odim, odim_trees in enumerate(stacked_trees[idx[ind]]):
+            for tree in odim_trees:
+                p[odim] += tree.predict(x=X, excluded=excluded, shape=leaves_shape)
+
+    # pred.reshape((*size_iter, shape, -1))
+    return pred.transpose((0, 3, 1, 2)).reshape((*size_iter, -1, shape))


 def plot_convergence(
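The new return line packs a fair amount: per draw, predictions are now accumulated as (trees_shape, leaves_shape, num_observations), while callers still expect (..., num_observations, shape) with shape = trees_shape * leaves_shape. The transpose moves the observation axis forward and the reshape merges the two output axes, which is safe because they end up adjacent. A numpy sketch with toy dimensions:

import numpy as np

flatten_size, trees_shape, leaves_shape, n_obs = 2, 3, 1, 5
shape = trees_shape * leaves_shape
size_iter = (1,)  # the size=None default in this commit

pred = np.zeros((flatten_size, trees_shape, leaves_shape, n_obs))
out = pred.transpose((0, 3, 1, 2)).reshape((*size_iter, -1, shape))

# observations now come before the merged output dimension
assert out.shape == (1, flatten_size * n_obs, shape)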

tests/test_bart.py

Lines changed: 27 additions & 0 deletions
@@ -217,3 +217,30 @@ def test_bart_moment(size, expected):
     with pm.Model() as model:
         pmb.BART("x", X=X, Y=Y, size=size)
     assert_moment_is_expected(model, expected)
+
+
+@pytest.mark.parametrize(
+    argnames="separate_trees,split_rule",
+    argvalues=[
+        (False, pmb.ContinuousSplitRule),
+        (False, pmb.OneHotSplitRule),
+        (False, pmb.SubsetSplitRule),
+        (True, pmb.ContinuousSplitRule),
+    ],
+    ids=["continuous", "one-hot", "subset", "separate-trees"],
+)
+def test_categorical_model(separate_trees, split_rule):
+
+    Y = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
+    X = np.concatenate([Y[:, None], np.random.randint(0, 6, size=(9, 4))], axis=1)
+
+    with pm.Model() as model:
+        lo = pmb.BART("logodds", X, Y, m=2, shape=(3, 9),
+                      split_rules=[split_rule] * 5,
+                      separate_trees=separate_trees)
+        y = pm.Categorical("y", p=pm.math.softmax(lo.T, axis=-1), observed=Y)
+        idata = pm.sample(random_seed=3415, tune=300, draws=300)
+        idata = pm.sample_posterior_predictive(idata, predictions=True, extend_inferencedata=True)
+
+    # Fit should be good enough that the right category is selected over 50% of the time
+    assert (idata.predictions.y.median(["chain", "draw"]) == Y).all()
