Skip to content

Commit 93e7026

Browse files
authored
bart refactor (#43)
* bart refactor * fix test
1 parent 036eaf5 commit 93e7026

File tree

2 files changed

+22
-14
lines changed

2 files changed

+22
-14
lines changed

pymc_experimental/bart/pgbart.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,12 @@ def init_particles(self, tree_id: int) -> np.ndarray:
227227
self.m,
228228
self.normal,
229229
)
230+
230231
# The old tree and the one with new leafs do not grow so we update the weights only once
231232
self.update_weight(p0, old=True)
232233
self.update_weight(p1, old=True)
233-
234234
particles = [p0, p1]
235+
235236
for _ in self.indices:
236237
pt = ParticleTree(self.a_tree)
237238
if self.tune:
@@ -396,15 +397,9 @@ def grow_tree(
396397
index_selected_predictor = ssv.rvs()
397398
selected_predictor = available_predictors[index_selected_predictor]
398399
available_splitting_values = X[idx_data_points, selected_predictor]
399-
if missing_data:
400-
idx_data_points = idx_data_points[~np.isnan(available_splitting_values)]
401-
available_splitting_values = available_splitting_values[
402-
~np.isnan(available_splitting_values)
403-
]
400+
split_value = get_split_value(available_splitting_values, idx_data_points, missing_data)
404401

405-
if available_splitting_values.size > 0:
406-
idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values))
407-
split_value = available_splitting_values[idx_selected_splitting_values]
402+
if split_value is not None:
408403

409404
new_idx_data_points = get_new_idx_data_points(
410405
split_value, idx_data_points, selected_predictor, X
@@ -439,7 +434,7 @@ def grow_tree(
439434
)
440435

441436
# update tree nodes and indexes
442-
tree.delete_node(index_leaf_node)
437+
tree.delete_leaf_node(index_leaf_node)
443438
tree.set_node(index_leaf_node, new_split_node)
444439
tree.set_node(new_nodes[0].index, new_nodes[0])
445440
tree.set_node(new_nodes[1].index, new_nodes[1])
@@ -456,6 +451,21 @@ def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X)
456451
return left_node_idx_data_points, right_node_idx_data_points
457452

458453

454+
def get_split_value(available_splitting_values, idx_data_points, missing_data):
455+
456+
if missing_data:
457+
idx_data_points = idx_data_points[~np.isnan(available_splitting_values)]
458+
available_splitting_values = available_splitting_values[
459+
~np.isnan(available_splitting_values)
460+
]
461+
462+
if available_splitting_values.size > 0:
463+
idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values))
464+
split_value = available_splitting_values[idx_selected_splitting_values]
465+
466+
return split_value
467+
468+
459469
def draw_leaf_value(Y_mu_pred, mean, m, normal, kf):
460470
"""Draw Gaussian distributed leaf values."""
461471
if Y_mu_pred.size == 0:

pymc_experimental/bart/tree.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,8 @@ def set_node(self, index, node):
6868
if isinstance(node, LeafNode):
6969
self.idx_leaf_nodes.append(index)
7070

71-
def delete_node(self, index):
72-
current_node = self.get_node(index)
73-
if isinstance(current_node, LeafNode):
74-
self.idx_leaf_nodes.remove(index)
71+
def delete_leaf_node(self, index):
72+
self.idx_leaf_nodes.remove(index)
7573
del self.tree_structure[index]
7674

7775
def trim(self):

0 commit comments

Comments
 (0)