@@ -54,6 +54,8 @@ def sample_tree(
54
54
missing_data ,
55
55
sum_trees ,
56
56
m ,
57
+ linear_fit ,
58
+ response ,
57
59
normal ,
58
60
shape ,
59
61
) -> bool :
@@ -73,6 +75,8 @@ def sample_tree(
73
75
missing_data ,
74
76
sum_trees ,
75
77
m ,
78
+ linear_fit ,
79
+ response ,
76
80
normal ,
77
81
self .kfactor ,
78
82
shape ,
@@ -131,6 +135,7 @@ def __init__(
131
135
132
136
self .missing_data = np .any (np .isnan (self .X ))
133
137
self .m = self .bart .m
138
+ self .response = self .bart .response
134
139
shape = initial_values [value_bart .name ].shape
135
140
self .shape = 1 if len (shape ) == 1 else shape [0 ]
136
141
@@ -160,6 +165,9 @@ def __init__(
160
165
num_observations = self .num_observations ,
161
166
shape = self .shape ,
162
167
)
168
+
169
+ self .linear_fit = fast_linear_fit ()
170
+
163
171
self .normal = NormalSampler (mu_std , self .shape )
164
172
self .uniform = UniformSampler (0 , 1 )
165
173
self .uniform_kf = UniformSampler (0.33 , 0.75 , self .shape )
@@ -209,6 +217,8 @@ def astep(self, _):
209
217
self .missing_data ,
210
218
self .sum_trees ,
211
219
self .m ,
220
+ self .linear_fit ,
221
+ self .response ,
212
222
self .normal ,
213
223
self .shape ,
214
224
):
@@ -393,6 +403,8 @@ def grow_tree(
393
403
missing_data ,
394
404
sum_trees ,
395
405
m ,
406
+ linear_fit ,
407
+ response ,
396
408
normal ,
397
409
kfactor ,
398
410
shape ,
@@ -602,3 +614,27 @@ def logp(point, out_vars, vars, shared): # pylint: disable=redefined-builtin
602
614
function = pytensor_function ([inarray0 ], out_list [0 ])
603
615
function .trust_input = True
604
616
return function
617
+
618
+
619
+ def fast_linear_fit ():
620
+ """If available use Numba to speed up the computation of the linear fit"""
621
+
622
+ def linear_fit (X , Y ):
623
+
624
+ n = len (Y )
625
+ xbar = np .sum (X ) / n
626
+ ybar = np .sum (Y ) / n
627
+
628
+ b = (X @ Y - n * xbar * ybar ) / (X @ X - n * xbar ** 2 )
629
+ a = ybar - b * xbar
630
+
631
+ y_fit = a + b * X
632
+ return y_fit , (a , b )
633
+
634
+ try :
635
+ from numba import jit
636
+
637
+ return jit (linear_fit )
638
+
639
+ except ImportError :
640
+ return linear_fit
0 commit comments