@@ -227,11 +227,12 @@ def init_particles(self, tree_id: int) -> np.ndarray:
227
227
self .m ,
228
228
self .normal ,
229
229
)
230
+
230
231
# The old tree and the one with new leafs do not grow so we update the weights only once
231
232
self .update_weight (p0 , old = True )
232
233
self .update_weight (p1 , old = True )
233
-
234
234
particles = [p0 , p1 ]
235
+
235
236
for _ in self .indices :
236
237
pt = ParticleTree (self .a_tree )
237
238
if self .tune :
@@ -396,15 +397,9 @@ def grow_tree(
396
397
index_selected_predictor = ssv .rvs ()
397
398
selected_predictor = available_predictors [index_selected_predictor ]
398
399
available_splitting_values = X [idx_data_points , selected_predictor ]
399
- if missing_data :
400
- idx_data_points = idx_data_points [~ np .isnan (available_splitting_values )]
401
- available_splitting_values = available_splitting_values [
402
- ~ np .isnan (available_splitting_values )
403
- ]
400
+ split_value = get_split_value (available_splitting_values , idx_data_points , missing_data )
404
401
405
- if available_splitting_values .size > 0 :
406
- idx_selected_splitting_values = discrete_uniform_sampler (len (available_splitting_values ))
407
- split_value = available_splitting_values [idx_selected_splitting_values ]
402
+ if split_value is not None :
408
403
409
404
new_idx_data_points = get_new_idx_data_points (
410
405
split_value , idx_data_points , selected_predictor , X
@@ -439,7 +434,7 @@ def grow_tree(
439
434
)
440
435
441
436
# update tree nodes and indexes
442
- tree .delete_node (index_leaf_node )
437
+ tree .delete_leaf_node (index_leaf_node )
443
438
tree .set_node (index_leaf_node , new_split_node )
444
439
tree .set_node (new_nodes [0 ].index , new_nodes [0 ])
445
440
tree .set_node (new_nodes [1 ].index , new_nodes [1 ])
@@ -456,6 +451,21 @@ def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X)
456
451
return left_node_idx_data_points , right_node_idx_data_points
457
452
458
453
454
+ def get_split_value (available_splitting_values , idx_data_points , missing_data ):
455
+
456
+ if missing_data :
457
+ idx_data_points = idx_data_points [~ np .isnan (available_splitting_values )]
458
+ available_splitting_values = available_splitting_values [
459
+ ~ np .isnan (available_splitting_values )
460
+ ]
461
+
462
+ if available_splitting_values .size > 0 :
463
+ idx_selected_splitting_values = discrete_uniform_sampler (len (available_splitting_values ))
464
+ split_value = available_splitting_values [idx_selected_splitting_values ]
465
+
466
+ return split_value
467
+
468
+
459
469
def draw_leaf_value (Y_mu_pred , mean , m , normal , kf ):
460
470
"""Draw Gaussian distributed leaf values."""
461
471
if Y_mu_pred .size == 0 :
0 commit comments