Commit 8f5b293

Author: juanitorduz
improve code logic
1 parent 30fd478 commit 8f5b293

File tree: 2 files changed, +31 -12 lines


pymc_bart/tree.py

Lines changed: 19 additions & 7 deletions
@@ -29,7 +29,7 @@ class Node:
     index : int
     value : float
     idx_data_points : Optional[npt.NDArray[np.int_]]
-    idx_split_variable : Optional[npt.NDArray[np.int_]],
+    idx_split_variable : int
     linear_params: Optional[List[float]] = None
     """

@@ -190,7 +190,7 @@ def _predict(self) -> npt.NDArray[np.float_]:
         return output.T

     def predict(
-        self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = None
+        self, x: npt.NDArray[np.float_], m: int, excluded: Optional[List[int]] = None
     ) -> npt.NDArray[np.float_]:
         """
         Predict output of tree for an (un)observed point x.
@@ -199,6 +199,8 @@ def predict(
         ----------
         x : npt.NDArray[np.float_]
             Unobserved point
+        m : int
+            Number of trees
         excluded: Optional[List[int]]
             Indexes of the variables to exclude when computing predictions
@@ -209,12 +211,14 @@ def predict(
         """
         if excluded is None:
             excluded = []
-        return self._traverse_tree(x, 0, excluded)
+        return self._traverse_tree(x=x, m=m, node_index=0, split_variable=-1, excluded=excluded)

     def _traverse_tree(
         self,
         x: npt.NDArray[np.float_],
+        m: int,
         node_index: int,
+        split_variable: int = -1,
         excluded: Optional[List[int]] = None,
     ) -> npt.NDArray[np.float_]:
         """
@@ -224,8 +228,12 @@ def _traverse_tree(
         ----------
         x : npt.NDArray[np.float_]
             Unobserved point
+        m : int
+            Number of trees
         node_index : int
             Index of the node to start the traversal from
+        split_variable : int
+            Index of the variable used to split the node
         excluded: Optional[List[int]]
             Indexes of the variables to exclude when computing predictions
@@ -235,13 +243,15 @@ def _traverse_tree(
             Leaf node value or mean of leaf node values
         """
         current_node = self.get_node(node_index)
-
         if current_node.is_leaf_node():
-            if current_node.linear_params is not None:
+            if current_node.linear_params is None:
                 return np.array(current_node.value)
+
             x = x[split_variable].item()
             y_x = current_node.linear_params[0] + current_node.linear_params[1] * x
-            return y_x / m
+            return np.array(y_x / m)
+
+        split_variable = current_node.idx_split_variable

         if excluded is not None and current_node.idx_split_variable in excluded:
             leaf_values: List[float] = []
@@ -252,7 +262,9 @@ def _traverse_tree(
                 next_node = current_node.get_idx_left_child()
             else:
                 next_node = current_node.get_idx_right_child()
-            return self._traverse_tree(x=x, node_index=next_node, excluded=excluded)
+            return self._traverse_tree(
+                x=x, m=m, node_index=next_node, split_variable=split_variable, excluded=excluded
+            )

     def _traverse_leaf_values(self, leaf_values: List[float], node_index: int) -> None:
         """

pymc_bart/utils.py

Lines changed: 12 additions & 5 deletions
@@ -21,6 +21,7 @@
 def _sample_posterior(
     all_trees: List[List[Tree]],
     X: TensorLike,
+    m: int,
     rng: np.random.Generator,
     size: Optional[Union[int, Tuple[int, ...]]] = None,
     excluded: Optional[npt.NDArray[np.int_]] = None,
@@ -35,6 +36,8 @@ def _sample_posterior(
     X : tensor-like
         A covariate matrix. Use the same used to fit BART for in-sample predictions or a new one for
         out-of-sample predictions.
+    m : int
+        Number of trees
     rng : NumPy RandomGenerator
     size : int or tuple
         Number of samples.
@@ -57,7 +60,7 @@ def _sample_posterior(
         flatten_size *= s

     idx = rng.integers(0, len(stacked_trees), size=flatten_size)
-    shape = stacked_trees[0][0].predict(X[0]).size
+    shape = stacked_trees[0][0].predict(x=X[0], m=m).size

     pred = np.zeros((flatten_size, X.shape[0], shape))

@@ -220,6 +223,8 @@ def plot_dependence(
     -------
     axes: matplotlib axes
     """
+    m: int = bartrv.owner.op.m
+
     if kind not in ["pdp", "ice"]:
         raise ValueError(f"kind={kind} is not suported. Available option are 'pdp' or 'ice'")

@@ -294,15 +299,15 @@ def plot_dependence(
                 new_X[:, indices_mi] = X[:, indices_mi]
                 new_X[:, i] = x_i
                 y_pred.append(
-                    np.mean(_sample_posterior(all_trees, X=new_X, rng=rng, size=samples), 1)
+                    np.mean(_sample_posterior(all_trees, X=new_X, m=m, rng=rng, size=samples), 1)
                 )
                 new_x_target.append(new_x_i)
         else:
             for instance in instances_ary:
                 new_X = X[idx_s]
                 new_X[:, indices_mi] = X[:, indices_mi][instance]
                 y_pred.append(
-                    np.mean(_sample_posterior(all_trees, X=new_X, rng=rng, size=samples), 0)
+                    np.mean(_sample_posterior(all_trees, X=new_X, m=m, rng=rng, size=samples), 0)
                 )
                 new_x_target.append(new_X[:, i])
         y_mins.append(np.min(y_pred))
@@ -445,6 +450,8 @@ def plot_variable_importance(
     """
     _, axes = plt.subplots(2, 1, figsize=figsize)

+    m: int = bartrv.owner.op.m
+
     if hasattr(X, "columns") and hasattr(X, "values"):
         labels = X.columns
         X = X.values
@@ -474,13 +481,13 @@ def plot_variable_importance(

     all_trees = bartrv.owner.op.all_trees

-    predicted_all = _sample_posterior(all_trees, X=X, rng=rng, size=samples, excluded=None)
+    predicted_all = _sample_posterior(all_trees, X=X, m=m, rng=rng, size=samples, excluded=None)

     ev_mean = np.zeros(len(var_imp))
     ev_hdi = np.zeros((len(var_imp), 2))
     for idx, subset in enumerate(subsets):
         predicted_subset = _sample_posterior(
-            all_trees=all_trees, X=X, rng=rng, size=samples, excluded=subset
+            all_trees=all_trees, X=X, m=m, rng=rng, size=samples, excluded=subset
         )
         pearson = np.zeros(samples)
         for j in range(samples):
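
End to end, the new m argument is read from the BART random variable's op and handed to the internal _sample_posterior helper, exactly as plot_dependence and plot_variable_importance now do in the hunks above. The sketch below is illustrative only: it assumes pymc and pymc-bart are installed, the synthetic data and hyperparameters are arbitrary, and the private _sample_posterior is called directly purely to show the new keyword.

import numpy as np
import pymc as pm
import pymc_bart as pmb
from pymc_bart.utils import _sample_posterior

# Arbitrary synthetic data for illustration.
rng = np.random.default_rng(42)
X = np.linspace(0.0, 1.0, 100)[:, None]
y = np.sin(2 * np.pi * X[:, 0]) + rng.normal(0.0, 0.1, size=100)

with pm.Model():
    mu = pmb.BART("mu", X, y, m=20)                   # forest with m = 20 trees
    pm.Normal("y_obs", mu=mu, sigma=0.1, observed=y)
    idata = pm.sample(tune=100, draws=100, chains=1, random_seed=42)

# After this commit, the tree count travels from the BART op into the helper.
all_trees = mu.owner.op.all_trees                     # trees collected during sampling
m = mu.owner.op.m                                     # same attribute the plot functions read
pred = _sample_posterior(all_trees, X=X, m=m, rng=rng, size=50)
print(pred.shape)                                     # (50, 100, 1) for a univariate response

Passing m explicitly, rather than letting each tree guess it, is what lets the linear-leaf scaling in tree.py work for both in-sample and out-of-sample covariate matrices.
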
