@@ -54,6 +54,7 @@ def sample_tree(
54
54
missing_data ,
55
55
sum_trees ,
56
56
m ,
57
+ response ,
57
58
normal ,
58
59
shape ,
59
60
) -> bool :
@@ -73,6 +74,7 @@ def sample_tree(
73
74
missing_data ,
74
75
sum_trees ,
75
76
m ,
77
+ response ,
76
78
normal ,
77
79
self .kfactor ,
78
80
shape ,
@@ -131,6 +133,7 @@ def __init__(
131
133
132
134
self .missing_data = np .any (np .isnan (self .X ))
133
135
self .m = self .bart .m
136
+ self .response = self .bart .response
134
137
shape = initial_values [value_bart .name ].shape
135
138
self .shape = 1 if len (shape ) == 1 else shape [0 ]
136
139
@@ -160,6 +163,7 @@ def __init__(
160
163
num_observations = self .num_observations ,
161
164
shape = self .shape ,
162
165
)
166
+
163
167
self .normal = NormalSampler (mu_std , self .shape )
164
168
self .uniform = UniformSampler (0 , 1 )
165
169
self .uniform_kf = UniformSampler (0.33 , 0.75 , self .shape )
@@ -209,6 +213,7 @@ def astep(self, _):
209
213
self .missing_data ,
210
214
self .sum_trees ,
211
215
self .m ,
216
+ self .response ,
212
217
self .normal ,
213
218
self .shape ,
214
219
):
@@ -393,6 +398,7 @@ def grow_tree(
393
398
missing_data ,
394
399
sum_trees ,
395
400
m ,
401
+ response ,
396
402
normal ,
397
403
kfactor ,
398
404
shape ,
@@ -402,8 +408,10 @@ def grow_tree(
402
408
403
409
index_selected_predictor = ssv .rvs ()
404
410
selected_predictor = available_predictors [index_selected_predictor ]
405
- available_splitting_values = X [idx_data_points , selected_predictor ]
406
- split_value = get_split_value (available_splitting_values , idx_data_points , missing_data )
411
+ idx_data_points , available_splitting_values = filter_missing_values (
412
+ X [idx_data_points , selected_predictor ], idx_data_points , missing_data
413
+ )
414
+ split_value = get_split_value (available_splitting_values )
407
415
408
416
if split_value is None :
409
417
return None
@@ -415,18 +423,24 @@ def grow_tree(
415
423
get_idx_right_child (index_leaf_node ),
416
424
)
417
425
426
+ if response == "mix" :
427
+ response = "linear" if np .random .random () >= 0.5 else "constant"
428
+
418
429
for idx in range (2 ):
419
430
idx_data_point = new_idx_data_points [idx ]
420
- node_value = draw_leaf_value (
421
- sum_trees [:, idx_data_point ],
422
- m ,
423
- normal .rvs () * kfactor ,
424
- shape ,
431
+ node_value , linear_params = draw_leaf_value (
432
+ y_mu_pred = sum_trees [:, idx_data_point ],
433
+ x_mu = X [idx_data_point , selected_predictor ],
434
+ m = m ,
435
+ norm = normal .rvs () * kfactor ,
436
+ shape = shape ,
437
+ response = response ,
425
438
)
426
439
427
440
new_node = Node .new_leaf_node (
428
441
value = node_value ,
429
442
idx_data_points = idx_data_point ,
443
+ linear_params = linear_params ,
430
444
)
431
445
tree .set_node (current_node_children [idx ], new_node )
432
446
@@ -440,39 +454,52 @@ def get_new_idx_data_points(available_splitting_values, split_value, idx_data_po
440
454
return idx_data_points [split_idx ], idx_data_points [~ split_idx ]
441
455
442
456
443
def filter_missing_values(available_splitting_values, idx_data_points, missing_data):
    """Drop NaN entries from candidate split values and their data-point indices.

    When ``missing_data`` is False the inputs are returned unchanged (same
    objects); otherwise both arrays are filtered with a shared NaN mask so
    they stay aligned.
    """
    if not missing_data:
        return idx_data_points, available_splitting_values
    keep = ~np.isnan(available_splitting_values)
    return idx_data_points[keep], available_splitting_values[keep]
449
464
465
def get_split_value(available_splitting_values):
    """Pick one splitting value uniformly at random, or None when none remain."""
    if available_splitting_values.size == 0:
        return None
    chosen = discrete_uniform_sampler(len(available_splitting_values))
    return available_splitting_values[chosen]
456
471
457
472
458
- @njit
459
- def draw_leaf_value (y_mu_pred , m , norm , shape ):
473
def draw_leaf_value(
    y_mu_pred: npt.NDArray[np.float_],
    x_mu: npt.NDArray[np.float_],
    m: int,
    norm: npt.NDArray[np.float_],
    shape: int,
    response: str,
) -> Tuple[npt.NDArray[np.float_], Optional[npt.NDArray[np.float_]]]:
    """Draw Gaussian distributed leaf values.

    Parameters
    ----------
    y_mu_pred : values assigned to this leaf's data points.
    x_mu : predictor values for the same data points (used only when
        ``response == "linear"``).
    m : number of trees; each leaf contributes 1/m of the fit.
    norm : pre-drawn Gaussian noise (already scaled by kfactor at the caller).
    shape : dimensionality of the leaf value.
    response : "constant" or "linear" ("mix" is resolved by the caller).

    Returns
    -------
    Tuple of the drawn leaf value and the linear fit parameters, which are
    None unless ``response == "linear"`` with more than one data point.
    """
    linear_params = None
    if y_mu_pred.size == 0:
        # Empty leaf: nothing to average, return zeros with no added noise.
        return np.zeros(shape), linear_params

    if y_mu_pred.size == 1:
        # A single point: its value split evenly across the m trees.
        mu_mean = np.full(shape, y_mu_pred.item() / m)
    elif response == "constant":
        mu_mean = fast_mean(y_mu_pred) / m
    elif response == "linear":
        y_fit, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred)
        mu_mean = y_fit / m
    else:
        # Fail loudly: the original fell through and returned an
        # uninitialized np.empty(shape) for an unrecognized response.
        raise ValueError(f"Unknown response: {response}")

    return norm + mu_mean, linear_params
470
498
471
499
472
500
@njit
473
- def fast_mean (ari ) :
501
+ def fast_mean (ari : npt . NDArray [ np . float_ ]) -> Union [ float , npt . NDArray [ np . float_ ]] :
474
502
"""Use Numba to speed up the computation of the mean."""
475
-
476
503
if ari .ndim == 1 :
477
504
count = ari .shape [0 ]
478
505
suma = 0
@@ -488,6 +515,31 @@ def fast_mean(ari):
488
515
return res / count
489
516
490
517
518
@njit
def fast_linear_fit(
    x: npt.NDArray[np.float_], y: npt.NDArray[np.float_]
) -> Tuple[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]]:
    """Least-squares fit of each row of ``y`` on the single predictor ``x``.

    Returns the fitted values (transposed to one column per output dim) and
    the [intercept, slope] parameters. A zero-variance predictor yields a
    zero slope instead of dividing by zero.
    """
    n = len(x)

    mean_x = np.sum(x) / n
    mean_y = np.sum(y, axis=1) / n

    centered_x = x - mean_x
    centered_y = y - np.expand_dims(mean_y, axis=1)

    var_x = np.dot(centered_x, centered_x.T)

    if var_x == 0:
        slope = np.zeros(y.shape[0])
    else:
        slope = np.dot(centered_x, centered_y.T) / var_x

    intercept = mean_y - slope * mean_x

    fitted = np.expand_dims(intercept, axis=1) + np.expand_dims(slope, axis=1) * x
    return fitted.T, [intercept, slope]
541
+
542
+
491
543
def discrete_uniform_sampler (upper_value ):
492
544
"""Draw from the uniform distribution with bounds [0, upper_value).
493
545
0 commit comments