Skip to content

Commit 7d3edc7

Browse files
author
juanitorduz
committed
add signature in pbbart to add linear nodes
1 parent 8cf7028 commit 7d3edc7

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

pymc_bart/pgbart.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def sample_tree(
5454
missing_data,
5555
sum_trees,
5656
m,
57+
linear_fit,
58+
response,
5759
normal,
5860
shape,
5961
) -> bool:
@@ -73,6 +75,8 @@ def sample_tree(
7375
missing_data,
7476
sum_trees,
7577
m,
78+
linear_fit,
79+
response,
7680
normal,
7781
self.kfactor,
7882
shape,
@@ -131,6 +135,7 @@ def __init__(
131135

132136
self.missing_data = np.any(np.isnan(self.X))
133137
self.m = self.bart.m
138+
self.response = self.bart.response
134139
shape = initial_values[value_bart.name].shape
135140
self.shape = 1 if len(shape) == 1 else shape[0]
136141

@@ -160,6 +165,9 @@ def __init__(
160165
num_observations=self.num_observations,
161166
shape=self.shape,
162167
)
168+
169+
self.linear_fit = fast_linear_fit()
170+
163171
self.normal = NormalSampler(mu_std, self.shape)
164172
self.uniform = UniformSampler(0, 1)
165173
self.uniform_kf = UniformSampler(0.33, 0.75, self.shape)
@@ -209,6 +217,8 @@ def astep(self, _):
209217
self.missing_data,
210218
self.sum_trees,
211219
self.m,
220+
self.linear_fit,
221+
self.response,
212222
self.normal,
213223
self.shape,
214224
):
@@ -393,6 +403,8 @@ def grow_tree(
393403
missing_data,
394404
sum_trees,
395405
m,
406+
linear_fit,
407+
response,
396408
normal,
397409
kfactor,
398410
shape,
@@ -602,3 +614,27 @@ def logp(point, out_vars, vars, shared): # pylint: disable=redefined-builtin
602614
function = pytensor_function([inarray0], out_list[0])
603615
function.trust_input = True
604616
return function
617+
618+
619+
def fast_linear_fit():
620+
"""If available use Numba to speed up the computation of the linear fit"""
621+
622+
def linear_fit(X, Y):
623+
624+
n = len(Y)
625+
xbar = np.sum(X) / n
626+
ybar = np.sum(Y) / n
627+
628+
b = (X @ Y - n * xbar * ybar) / (X @ X - n * xbar**2)
629+
a = ybar - b * xbar
630+
631+
y_fit = a + b * X
632+
return y_fit, (a, b)
633+
634+
try:
635+
from numba import jit
636+
637+
return jit(linear_fit)
638+
639+
except ImportError:
640+
return linear_fit

0 commit comments

Comments
 (0)