Commit 144e048
Linear nodes attempt (v2) (#79)
* linear nodes init improve code logic add m param add m param to tests add signature in pbbart to add linear nodes
* fix lint
* Update pymc_bart/pgbart.py
  Co-authored-by: Osvaldo A Martin <aloctavodia@gmail.com>
* address comments
* fix arg error
* add bracket
* clean func
* fix rm init param linear_params
* clean up
* minor improvements
* fix mean denominator
* fix shape
* add tests
* fix working implementation
* try fix shapre vect
* undo last commit
* Update pymc_bart/pgbart.py
  Co-authored-by: Osvaldo A Martin <aloctavodia@gmail.com>
* address comments
* change default response
* better unpacking
* some type hints
* fix trim method
* handle zero variance fast linear fit
* simplify condition
* linear response in more than 1d
* make node value type an array so that is compatible with vectorization
* rm .item() from fast linear mdoel output
* improve missing values filters
* improve missing values filters

---------

Co-authored-by: Osvaldo A Martin <aloctavodia@gmail.com>
1 parent b4043ba commit 144e048

File tree

7 files changed (+222 lines, -61 lines)

.pylintrc

Lines changed: 3 additions & 0 deletions
@@ -256,6 +256,9 @@ good-names=i,
            rv,
            new_X,
            new_y,
+           a,
+           b,
+           n,


 # Include a hint for the correct naming format with invalid-name

pymc_bart/bart.py

Lines changed: 6 additions & 1 deletion
@@ -52,7 +52,7 @@ def rng_fn(cls, rng=None, X=None, Y=None, m=None, alpha=None, split_prior=None,
             else:
                 return np.full(cls.Y.shape[0], cls.Y.mean())
         else:
-            return _sample_posterior(cls.all_trees, cls.X, rng=rng).squeeze().T
+            return _sample_posterior(cls.all_trees, cls.X, cls.m, rng=rng).squeeze().T


 bart = BARTRV()
@@ -72,6 +72,9 @@ class BART(Distribution):
         The response vector.
     m : int
         Number of trees
+    response : str
+        How the leaf_node values are computed. Available options are ``constant``, ``linear`` or
+        ``mix``. Defaults to ``constant``.
     alpha : float
         Control the prior probability over the depth of the trees. Even when it can takes values in
         the interval (0, 1), it is recommended to be in the interval (0, 0.5].
@@ -88,6 +91,7 @@ def __new__(
         Y: TensorLike,
         m: int = 50,
         alpha: float = 0.25,
+        response: str = "constant",
         split_prior: Optional[List[float]] = None,
         **kwargs,
     ):
@@ -110,6 +114,7 @@ def __new__(
                 X=X,
                 Y=Y,
                 m=m,
+                response=response,
                 alpha=alpha,
                 split_prior=split_prior,
             ),
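
The docstring addition above is the user-facing change: ``BART`` now accepts a ``response`` argument. Below is a minimal usage sketch, assuming ``pymc`` and a ``pymc-bart`` build containing this commit are installed; the toy ``X``/``Y`` data and the ``sigma``/``y_obs`` names are made up for illustration.

```python
import numpy as np
import pymc as pm
import pymc_bart as pmb

# Hypothetical toy data: 100 observations, 2 predictors.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
Y = X[:, 0] + rng.normal(scale=0.1, size=100)

with pm.Model() as model:
    # `response` selects how leaf values are computed:
    # "constant" (the default), "linear", or "mix".
    mu = pmb.BART("mu", X=X, Y=Y, m=50, response="linear")
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("y_obs", mu=mu, sigma=sigma, observed=Y)
```
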

pymc_bart/pgbart.py

Lines changed: 72 additions & 20 deletions
@@ -54,6 +54,7 @@ def sample_tree(
         missing_data,
         sum_trees,
         m,
+        response,
         normal,
         shape,
     ) -> bool:
@@ -73,6 +74,7 @@ def sample_tree(
                     missing_data,
                     sum_trees,
                     m,
+                    response,
                     normal,
                     self.kfactor,
                     shape,
@@ -131,6 +133,7 @@ def __init__(

         self.missing_data = np.any(np.isnan(self.X))
         self.m = self.bart.m
+        self.response = self.bart.response
         shape = initial_values[value_bart.name].shape
         self.shape = 1 if len(shape) == 1 else shape[0]

@@ -160,6 +163,7 @@ def __init__(
             num_observations=self.num_observations,
             shape=self.shape,
         )
+
         self.normal = NormalSampler(mu_std, self.shape)
         self.uniform = UniformSampler(0, 1)
         self.uniform_kf = UniformSampler(0.33, 0.75, self.shape)
@@ -209,6 +213,7 @@ def astep(self, _):
                        self.missing_data,
                        self.sum_trees,
                        self.m,
+                        self.response,
                        self.normal,
                        self.shape,
                    ):
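
The hunks above only thread ``response`` from the BART random variable into the sampler (``self.response = self.bart.response``); nothing new is passed to ``PGBART`` itself. A sketch, reusing the hypothetical ``model`` and ``mu`` from the earlier example; ``pm.sample`` typically assigns the PGBART step for BART variables on its own, and it is spelled out here only to show where the setting ends up.

```python
import pymc as pm
import pymc_bart as pmb

with model:
    # PGBART reads `response` (and `m`) from the BART variable, not from its own arguments.
    step = pmb.PGBART([mu])
    idata = pm.sample(draws=500, tune=500, step=step)
```
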
@@ -393,6 +398,7 @@ def grow_tree(
     missing_data,
     sum_trees,
     m,
+    response,
     normal,
     kfactor,
     shape,
@@ -402,8 +408,10 @@

     index_selected_predictor = ssv.rvs()
     selected_predictor = available_predictors[index_selected_predictor]
-    available_splitting_values = X[idx_data_points, selected_predictor]
-    split_value = get_split_value(available_splitting_values, idx_data_points, missing_data)
+    idx_data_points, available_splitting_values = filter_missing_values(
+        X[idx_data_points, selected_predictor], idx_data_points, missing_data
+    )
+    split_value = get_split_value(available_splitting_values)

     if split_value is None:
         return None
@@ -415,18 +423,24 @@
         get_idx_right_child(index_leaf_node),
     )

+    if response == "mix":
+        response = "linear" if np.random.random() >= 0.5 else "constant"
+
     for idx in range(2):
         idx_data_point = new_idx_data_points[idx]
-        node_value = draw_leaf_value(
-            sum_trees[:, idx_data_point],
-            m,
-            normal.rvs() * kfactor,
-            shape,
+        node_value, linear_params = draw_leaf_value(
+            y_mu_pred=sum_trees[:, idx_data_point],
+            x_mu=X[idx_data_point, selected_predictor],
+            m=m,
+            norm=normal.rvs() * kfactor,
+            shape=shape,
+            response=response,
         )

         new_node = Node.new_leaf_node(
             value=node_value,
             idx_data_points=idx_data_point,
+            linear_params=linear_params,
         )
         tree.set_node(current_node_children[idx], new_node)

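
With ``response="mix"``, the choice between a linear and a constant leaf is made independently for every grow move, as the added ``np.random.random() >= 0.5`` line shows. A standalone restatement of that selection logic follows; the helper name ``resolve_response`` is made up and is not part of the library.

```python
import numpy as np


def resolve_response(response: str) -> str:
    # "mix" resolves to "linear" or "constant" with equal probability,
    # independently each time a leaf node is grown.
    if response == "mix":
        return "linear" if np.random.random() >= 0.5 else "constant"
    return response


# Roughly half of the grown leaves end up linear, half constant.
choices = [resolve_response("mix") for _ in range(1000)]
print(choices.count("linear") / len(choices))  # ~0.5
```
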
@@ -440,39 +454,52 @@ def get_new_idx_data_points(available_splitting_values, split_value, idx_data_po
     return idx_data_points[split_idx], idx_data_points[~split_idx]


-def get_split_value(available_splitting_values, idx_data_points, missing_data):
+def filter_missing_values(available_splitting_values, idx_data_points, missing_data):
     if missing_data:
-        idx_data_points = idx_data_points[~np.isnan(available_splitting_values)]
-        available_splitting_values = available_splitting_values[
-            ~np.isnan(available_splitting_values)
-        ]
+        mask = ~np.isnan(available_splitting_values)
+        idx_data_points = idx_data_points[mask]
+        available_splitting_values = available_splitting_values[mask]
+    return idx_data_points, available_splitting_values
+

+def get_split_value(available_splitting_values):
     split_value = None
     if available_splitting_values.size > 0:
         idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values))
         split_value = available_splitting_values[idx_selected_splitting_values]
-
     return split_value


-@njit
-def draw_leaf_value(y_mu_pred, m, norm, shape):
+def draw_leaf_value(
+    y_mu_pred: npt.NDArray[np.float_],
+    x_mu: npt.NDArray[np.float_],
+    m: int,
+    norm: npt.NDArray[np.float_],
+    shape: int,
+    response: str,
+) -> Tuple[npt.NDArray[np.float_], Optional[npt.NDArray[np.float_]]]:
     """Draw Gaussian distributed leaf values."""
+    linear_params = None
+    mu_mean = np.empty(shape)
     if y_mu_pred.size == 0:
-        return np.zeros(shape)
+        return np.zeros(shape), linear_params

     if y_mu_pred.size == 1:
         mu_mean = np.full(shape, y_mu_pred.item() / m)
     else:
-        mu_mean = fast_mean(y_mu_pred) / m
+        if response == "constant":
+            mu_mean = fast_mean(y_mu_pred) / m
+        if response == "linear":
+            y_fit, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred)
+            mu_mean = y_fit / m

-    return norm + mu_mean
+    draw = norm + mu_mean
+    return draw, linear_params


 @njit
-def fast_mean(ari):
+def fast_mean(ari: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_]]:
     """Use Numba to speed up the computation of the mean."""
-
     if ari.ndim == 1:
         count = ari.shape[0]
         suma = 0
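
The new ``draw_leaf_value`` returns both the leaf value and, for linear leaves, the fitted ``[a, b]`` parameters. Below is a plain-NumPy sketch of what the two response modes compute as ``mu_mean`` for a single leaf (the Gaussian jitter ``norm`` added on top is left out). The numbers and variable values are made up, and the library itself uses the Numba-jitted ``fast_mean`` / ``fast_linear_fit`` rather than ``np.mean`` / ``np.polyfit``.

```python
import numpy as np

m = 50                                               # number of trees
x_mu = np.array([0.0, 1.0, 2.0, 3.0, 4.0])           # selected predictor at the leaf
y_mu_pred = np.array([[1.0, 3.0, 5.0, 7.0, 9.0]])    # residual targets, shape (1, n)

# response="constant": the leaf contributes its mean, scaled by 1/m.
mu_constant = y_mu_pred.mean(axis=1) / m             # shape (1,)

# response="linear": the leaf contributes a least-squares line a + b * x, scaled by 1/m.
b, a = np.polyfit(x_mu, y_mu_pred[0], deg=1)         # slope first, then intercept
mu_linear = (a + b * x_mu) / m                       # shape (n,)

print(mu_constant, mu_linear)
```
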
@@ -488,6 +515,31 @@ def fast_mean(ari):
     return res / count


+@njit
+def fast_linear_fit(
+    x: npt.NDArray[np.float_], y: npt.NDArray[np.float_]
+) -> Tuple[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]]:
+    n = len(x)
+
+    xbar = np.sum(x) / n
+    ybar = np.sum(y, axis=1) / n
+
+    x_diff = x - xbar
+    y_diff = y - np.expand_dims(ybar, axis=1)
+
+    x_var = np.dot(x_diff, x_diff.T)
+
+    if x_var == 0:
+        b = np.zeros(y.shape[0])
+    else:
+        b = np.dot(x_diff, y_diff.T) / x_var
+
+    a = ybar - b * xbar
+
+    y_fit = np.expand_dims(a, axis=1) + np.expand_dims(b, axis=1) * x
+    return y_fit.T, [a, b]
+
+
 def discrete_uniform_sampler(upper_value):
     """Draw from the uniform distribution with bounds [0, upper_value).
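
``fast_linear_fit`` is a closed-form simple linear regression: slope ``b = cov(x, y) / var(x)``, intercept ``a = ybar - b * xbar``, with the slope forced to zero when every ``x`` in the leaf is identical (the zero-variance guard from the commit message). A quick sanity check of that formula against ``np.polyfit`` for the 1-d case, using toy data that is not part of the test suite:

```python
import numpy as np

rng = np.random.default_rng(42)
x = rng.normal(size=20)
y = 2.0 * x - 1.0 + rng.normal(scale=0.1, size=20)

# Closed form used by fast_linear_fit (1-d case).
x_diff = x - x.mean()
b = np.dot(x_diff, y - y.mean()) / np.dot(x_diff, x_diff)
a = y.mean() - b * x.mean()

# Agrees with np.polyfit up to floating point error.
slope, intercept = np.polyfit(x, y, deg=1)
assert np.allclose([b, a], [slope, intercept])

# Zero-variance guard: identical x values give var(x) == 0, so the slope
# is set to 0 and the fit collapses to the leaf mean.
x_const = np.ones_like(x)
assert np.dot(x_const - x_const.mean(), x_const - x_const.mean()) == 0.0
```
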