Skip to content

Commit a73b53a

Browse files
committed
Addressed most issues flagged in PR comments
1 parent 94c664a commit a73b53a

File tree

5 files changed

+78
-35
lines changed

5 files changed

+78
-35
lines changed

pymc_bart/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from pymc_bart.bart import BART
1717
from pymc_bart.pgbart import PGBART
18-
from pymc_bart.split_rules import *
18+
from pymc_bart.split_rules import ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule
1919
from pymc_bart.utils import (
2020
plot_convergence,
2121
plot_pdp,

pymc_bart/bart.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
from .tree import Tree
2929
from .utils import TensorLike, _sample_posterior
30+
from .split_rules import SplitRule
3031

3132
__all__ = ["BART"]
3233

@@ -109,7 +110,7 @@ def __new__(
109110
beta: float = 2.0,
110111
response: str = "constant",
111112
split_prior: Optional[List[float]] = None,
112-
split_rules: Optional[List] = None,
113+
split_rules: Optional[SplitRule] = None,
113114
**kwargs,
114115
):
115116
manager = Manager()
@@ -138,7 +139,7 @@ def __new__(
138139
alpha=alpha,
139140
beta=beta,
140141
split_prior=split_prior,
141-
split_rules=split_rules
142+
split_rules=split_rules,
142143
),
143144
)()
144145

pymc_bart/pgbart.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from pymc_bart.tree import Node, Tree, get_idx_left_child, get_idx_right_child, get_depth
2929
from pymc_bart.split_rules import ContinuousSplitRule
3030

31+
3132
class ParticleTree:
3233
"""Particle tree."""
3334

@@ -169,7 +170,7 @@ def __init__(
169170
idx_data_points=np.arange(self.num_observations, dtype="int32"),
170171
num_observations=self.num_observations,
171172
shape=self.shape,
172-
split_rules = self.split_rules
173+
split_rules=self.split_rules,
173174
)
174175

175176
self.normal = NormalSampler(1, self.shape)
@@ -459,7 +460,7 @@ def grow_tree(
459460

460461
to_left = split_rule.divide(available_splitting_values, split_value)
461462
new_idx_data_points = idx_data_points[to_left], idx_data_points[~to_left]
462-
463+
463464
current_node_children = (
464465
get_idx_left_child(index_leaf_node),
465466
get_idx_right_child(index_leaf_node),
@@ -499,7 +500,6 @@ def filter_missing_values(available_splitting_values, idx_data_points, missing_d
499500
return idx_data_points, available_splitting_values
500501

501502

502-
503503
def draw_leaf_value(
504504
y_mu_pred: npt.NDArray[np.float_],
505505
x_mu: npt.NDArray[np.float_],

pymc_bart/split_rules.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,26 @@
1414

1515
from numba import njit
1616
import numpy as np
17+
from abc import abstractmethod
1718

18-
class ContinuousSplitRule:
19+
20+
class SplitRule:
21+
"""
22+
Abstract template class for a split rule
23+
"""
24+
25+
@staticmethod
26+
@abstractmethod
27+
def get_split_value(available_splitting_values):
28+
pass
29+
30+
@staticmethod
31+
@abstractmethod
32+
def divide(available_splitting_values, split_value):
33+
pass
34+
35+
36+
class ContinuousSplitRule(SplitRule):
1937
"""
2038
Standard continuous split rule: pick a pivot value and split depending on whether the variable is smaller or greater than the value picked.
2139
"""
@@ -24,7 +42,9 @@ class ContinuousSplitRule:
2442
def get_split_value(available_splitting_values):
2543
split_value = None
2644
if available_splitting_values.size > 1:
27-
idx_selected_splitting_values = int(np.random.random() * len(available_splitting_values))
45+
idx_selected_splitting_values = int(
46+
np.random.random() * len(available_splitting_values)
47+
)
2848
split_value = available_splitting_values[idx_selected_splitting_values]
2949
return split_value
3050

@@ -33,14 +53,19 @@ def get_split_value(available_splitting_values):
3353
def divide(available_splitting_values, split_value):
3454
return available_splitting_values <= split_value
3555

36-
class OneHotSplitRule:
56+
57+
class OneHotSplitRule(SplitRule):
3758
"""Choose a single categorical value and branch on whether the variable equals that value."""
3859

3960
@staticmethod
4061
def get_split_value(available_splitting_values):
4162
split_value = None
42-
if available_splitting_values.size > 1 and not np.all(available_splitting_values==available_splitting_values[0]):
43-
idx_selected_splitting_values = int(np.random.random() * len(available_splitting_values))
63+
if available_splitting_values.size > 1 and not np.all(
64+
available_splitting_values == available_splitting_values[0]
65+
):
66+
idx_selected_splitting_values = int(
67+
np.random.random() * len(available_splitting_values)
68+
)
4469
split_value = available_splitting_values[idx_selected_splitting_values]
4570
return split_value
4671

@@ -49,7 +74,8 @@ def get_split_value(available_splitting_values):
4974
def divide(available_splitting_values, split_value):
5075
return available_splitting_values == split_value
5176

52-
class SubsetSplitRule:
77+
78+
class SubsetSplitRule(SplitRule):
5379
"""
5480
Choose a random subset of the categorical values and branch on whether the value is within the chosen set.
5581
This is the approach taken in the flexBART paper.
@@ -58,11 +84,16 @@ class SubsetSplitRule:
5884
@staticmethod
5985
def get_split_value(available_splitting_values):
6086
split_value = None
61-
if available_splitting_values.size > 1 and not np.all(available_splitting_values==available_splitting_values[0]):
62-
unique_values = np.unique(available_splitting_values)[:-1] # Remove last one so it always goes to left
87+
if available_splitting_values.size > 1 and not np.all(
88+
available_splitting_values == available_splitting_values[0]
89+
):
90+
unique_values = np.unique(available_splitting_values)[
91+
:-1
92+
] # Remove last one so it always goes to left
6393
while True:
6494
sample = np.random.randint(0, 2, size=len(unique_values)).astype(bool)
65-
if np.any(sample): break
95+
if np.any(sample):
96+
break
6697
split_value = unique_values[sample]
6798
return split_value
6899

pymc_bart/tree.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,7 @@ class Tree:
114114
idx_leaf_nodes : List with the indices of the leaf nodes of the tree.
115115
"""
116116

117-
__slots__ = (
118-
"tree_structure",
119-
"output",
120-
"idx_leaf_nodes",
121-
"split_rules"
122-
)
117+
__slots__ = ("tree_structure", "output", "idx_leaf_nodes", "split_rules")
123118

124119
def __init__(
125120
self,
@@ -140,7 +135,7 @@ def new_tree(
140135
idx_data_points: Optional[npt.NDArray[np.int_]],
141136
num_observations: int,
142137
shape: int,
143-
split_rules: List[object]
138+
split_rules: List[object],
144139
) -> "Tree":
145140
return cls(
146141
tree_structure={
@@ -152,7 +147,7 @@ def new_tree(
152147
},
153148
idx_leaf_nodes=[0],
154149
output=np.zeros((num_observations, shape)).astype(config.floatX).squeeze(),
155-
split_rules=split_rules
150+
split_rules=split_rules,
156151
)
157152

158153
def __getitem__(self, index) -> Node:
@@ -173,7 +168,12 @@ def copy(self) -> "Tree":
173168
for k, v in self.tree_structure.items()
174169
}
175170
idx_leaf_nodes = self.idx_leaf_nodes.copy() if self.idx_leaf_nodes is not None else None
176-
return Tree(tree_structure=tree, idx_leaf_nodes=idx_leaf_nodes, output=self.output, split_rules=self.split_rules)
171+
return Tree(
172+
tree_structure=tree,
173+
idx_leaf_nodes=idx_leaf_nodes,
174+
output=self.output,
175+
split_rules=self.split_rules,
176+
)
177177

178178
def get_node(self, index: int) -> Node:
179179
return self.tree_structure[index]
@@ -207,7 +207,12 @@ def trim(self) -> "Tree":
207207
)
208208
for k, v in self.tree_structure.items()
209209
}
210-
return Tree(tree_structure=tree, idx_leaf_nodes=None, output=np.array([-1]), split_rules=self.split_rules)
210+
return Tree(
211+
tree_structure=tree,
212+
idx_leaf_nodes=None,
213+
output=np.array([-1]),
214+
split_rules=self.split_rules,
215+
)
211216

212217
def get_split_variables(self) -> Generator[int, None, None]:
213218
for node in self.tree_structure.values():
@@ -275,21 +280,25 @@ def _traverse_tree(
275280
Leaf node value or mean of leaf node values
276281
"""
277282

278-
x_shape = (1,) if len(X.shape)==1 else X.shape[:-1]
283+
x_shape = (1,) if len(X.shape) == 1 else X.shape[:-1]
279284

280-
stack = [(0, np.ones(x_shape)) ] # (node_index, weight) initial state
281-
p_d = np.zeros( shape + x_shape ) if isinstance(shape, tuple) else np.zeros( (shape,) + x_shape )
285+
stack = [(0, np.ones(x_shape))] # (node_index, weight) initial state
286+
p_d = (
287+
np.zeros(shape + x_shape) if isinstance(shape, tuple) else np.zeros((shape,) + x_shape)
288+
)
282289
while stack:
283290
node_index, weights = stack.pop()
284291
node = self.get_node(node_index)
285292
if node.is_leaf_node():
286293
params = node.linear_params
287-
nd_dims = (...,)+(None,)*len(x_shape)
294+
nd_dims = (...,) + (None,) * len(x_shape)
288295
if params is None:
289296
p_d += weights * node.value[nd_dims]
290297
else:
291298
# this produces nonsensical results
292-
p_d += weights * (params[0][nd_dims] + params[1][nd_dims] * X[...,node.idx_split_variable])
299+
p_d += weights * (
300+
params[0][nd_dims] + params[1][nd_dims] * X[..., node.idx_split_variable]
301+
)
293302
# this produces reasonable results
294303
# p_d += weight * node.value.mean()
295304
else:
@@ -301,15 +310,19 @@ def _traverse_tree(
301310
stack.append((left_node_index, weights * prop_nvalue_left))
302311
stack.append((right_node_index, weights * (1 - prop_nvalue_left)))
303312
else:
304-
to_left = self.split_rules[node.idx_split_variable].divide(X[...,node.idx_split_variable],node.value).astype('float')
313+
to_left = (
314+
self.split_rules[node.idx_split_variable]
315+
.divide(X[..., node.idx_split_variable], node.value)
316+
.astype("float")
317+
)
305318
stack.append((left_node_index, weights * to_left))
306319
stack.append((right_node_index, weights * (1 - to_left)))
307320

308-
if len(X.shape)==1: p_d = p_d[...,0]
321+
if len(X.shape) == 1:
322+
p_d = p_d[..., 0]
309323

310324
return p_d
311325

312-
313326
def _traverse_leaf_values(
314327
self, leaf_values: List[npt.NDArray[np.float_]], leaf_n_values: List[int], node_index: int
315328
) -> None:
@@ -328,5 +341,3 @@ def _traverse_leaf_values(
328341
else:
329342
self._traverse_leaf_values(leaf_values, leaf_n_values, get_idx_left_child(node_index))
330343
self._traverse_leaf_values(leaf_values, leaf_n_values, get_idx_right_child(node_index))
331-
332-

0 commit comments

Comments
 (0)