@@ -53,9 +53,11 @@ def sample_tree_sequential(
53
53
missing_data ,
54
54
sum_trees_output ,
55
55
mean ,
56
+ linear_fit ,
56
57
m ,
57
58
normal ,
58
59
mu_std ,
60
+ response ,
59
61
):
60
62
tree_grew = False
61
63
if self .expansion_nodes :
@@ -73,9 +75,11 @@ def sample_tree_sequential(
73
75
missing_data ,
74
76
sum_trees_output ,
75
77
mean ,
78
+ linear_fit ,
76
79
m ,
77
80
normal ,
78
81
mu_std ,
82
+ response ,
79
83
)
80
84
if tree_grew :
81
85
new_indexes = self .tree .idx_leaf_nodes [- 2 :]
@@ -97,11 +101,17 @@ class PGBART(ArrayStepShared):
97
101
Number of particles for the conditional SMC sampler. Defaults to 10
98
102
max_stages : int
99
103
Maximum number of iterations of the conditional SMC sampler. Defaults to 100.
100
- chunk = int
101
- Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees.
104
+ batch : int
105
+ Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees
106
+ during tuning and 20% after tuning.
102
107
model: PyMC Model
103
108
Optional model for sampling step. Defaults to None (taken from context).
104
109
110
+ Note
111
+ ----
112
+ This sampler is inspired by the [Lakshminarayanan2015] Particle Gibbs sampler, but introduces
113
+ several changes. The changes will be properly documented soon.
114
+
105
115
References
106
116
----------
107
117
.. [Lakshminarayanan2015] Lakshminarayanan, B. and Roy, D.M. and Teh, Y. W., (2015),
@@ -114,7 +124,7 @@ class PGBART(ArrayStepShared):
114
124
generates_stats = True
115
125
stats_dtypes = [{"variable_inclusion" : np .ndarray }]
116
126
117
- def __init__ (self , vars = None , num_particles = 10 , max_stages = 100 , chunk = "auto" , model = None ):
127
+ def __init__ (self , vars = None , num_particles = 10 , max_stages = 100 , batch = "auto" , model = None ):
118
128
_log .warning ("BART is experimental. Use with caution." )
119
129
model = modelcontext (model )
120
130
initial_values = model .initial_point
@@ -125,6 +135,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", mo
125
135
self .m = self .bart .m
126
136
self .alpha = self .bart .alpha
127
137
self .k = self .bart .k
138
+ self .response = self .bart .response
128
139
self .split_prior = self .bart .split_prior
129
140
if self .split_prior is None :
130
141
self .split_prior = np .ones (self .X .shape [1 ])
@@ -149,6 +160,8 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", mo
149
160
idx_data_points = np .arange (self .num_observations , dtype = "int32" ),
150
161
)
151
162
self .mean = fast_mean ()
163
+ self .linear_fit = fast_linear_fit ()
164
+
152
165
self .normal = NormalSampler ()
153
166
self .prior_prob_leaf_node = compute_prior_probability (self .alpha )
154
167
self .ssv = SampleSplittingVariable (self .split_prior )
@@ -157,10 +170,10 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", mo
157
170
self .idx = 0
158
171
self .iter = 0
159
172
self .sum_trees = []
160
- self .chunk = chunk
173
+ self .batch = batch
161
174
162
- if self .chunk == "auto" :
163
- self .chunk = max (1 , int (self .m * 0.1 ))
175
+ if self .batch == "auto" :
176
+ self .batch = max (1 , int (self .m * 0.1 ))
164
177
self .log_num_particles = np .log (num_particles )
165
178
self .indices = list (range (1 , num_particles ))
166
179
self .len_indices = len (self .indices )
@@ -190,7 +203,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
190
203
if self .idx == self .m :
191
204
self .idx = 0
192
205
193
- for tree_id in range (self .idx , self .idx + self .chunk ):
206
+ for tree_id in range (self .idx , self .idx + self .batch ):
194
207
if tree_id >= self .m :
195
208
break
196
209
# Generate an initial set of SMC particles
@@ -213,9 +226,11 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
213
226
self .missing_data ,
214
227
sum_trees_output ,
215
228
self .mean ,
229
+ self .linear_fit ,
216
230
self .m ,
217
231
self .normal ,
218
232
self .mu_std ,
233
+ self .response ,
219
234
)
220
235
if tree_grew :
221
236
self .update_weight (p )
@@ -251,6 +266,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
251
266
self .split_prior [index ] += 1
252
267
self .ssv = SampleSplittingVariable (self .split_prior )
253
268
else :
269
+ self .batch = max (1 , int (self .m * 0.2 ))
254
270
self .iter += 1
255
271
self .sum_trees .append (new_tree )
256
272
if not self .iter % self .m :
@@ -389,16 +405,20 @@ def grow_tree(
389
405
missing_data ,
390
406
sum_trees_output ,
391
407
mean ,
408
+ linear_fit ,
392
409
m ,
393
410
normal ,
394
411
mu_std ,
412
+ response ,
395
413
):
396
414
current_node = tree .get_node (index_leaf_node )
415
+ idx_data_points = current_node .idx_data_points
397
416
398
417
index_selected_predictor = ssv .rvs ()
399
418
selected_predictor = available_predictors [index_selected_predictor ]
400
- available_splitting_values = X [current_node . idx_data_points , selected_predictor ]
419
+ available_splitting_values = X [idx_data_points , selected_predictor ]
401
420
if missing_data :
421
+ idx_data_points = idx_data_points [~ np .isnan (available_splitting_values )]
402
422
available_splitting_values = available_splitting_values [
403
423
~ np .isnan (available_splitting_values )
404
424
]
@@ -407,58 +427,82 @@ def grow_tree(
407
427
return False , None
408
428
409
429
idx_selected_splitting_values = discrete_uniform_sampler (len (available_splitting_values ))
410
- selected_splitting_rule = available_splitting_values [idx_selected_splitting_values ]
430
+ split_value = available_splitting_values [idx_selected_splitting_values ]
411
431
new_split_node = SplitNode (
412
432
index = index_leaf_node ,
413
433
idx_split_variable = selected_predictor ,
414
- split_value = selected_splitting_rule ,
434
+ split_value = split_value ,
415
435
)
416
436
417
437
left_node_idx_data_points , right_node_idx_data_points = get_new_idx_data_points (
418
- new_split_node , current_node . idx_data_points , X
438
+ split_value , idx_data_points , selected_predictor , X
419
439
)
420
440
421
- left_node_value = draw_leaf_value (
422
- sum_trees_output [left_node_idx_data_points ], mean , m , normal , mu_std
441
+ if response == "mix" :
442
+ response = "linear" if np .random .random () >= 0.5 else "constant"
443
+
444
+ left_node_value , left_node_linear_params = draw_leaf_value (
445
+ sum_trees_output [left_node_idx_data_points ],
446
+ X [left_node_idx_data_points , selected_predictor ],
447
+ mean ,
448
+ linear_fit ,
449
+ m ,
450
+ normal ,
451
+ mu_std ,
452
+ response ,
423
453
)
424
- right_node_value = draw_leaf_value (
425
- sum_trees_output [right_node_idx_data_points ], mean , m , normal , mu_std
454
+ right_node_value , right_node_linear_params = draw_leaf_value (
455
+ sum_trees_output [right_node_idx_data_points ],
456
+ X [right_node_idx_data_points , selected_predictor ],
457
+ mean ,
458
+ linear_fit ,
459
+ m ,
460
+ normal ,
461
+ mu_std ,
462
+ response ,
426
463
)
427
464
428
465
new_left_node = LeafNode (
429
466
index = current_node .get_idx_left_child (),
430
467
value = left_node_value ,
431
468
idx_data_points = left_node_idx_data_points ,
469
+ linear_params = left_node_linear_params ,
432
470
)
433
471
new_right_node = LeafNode (
434
472
index = current_node .get_idx_right_child (),
435
473
value = right_node_value ,
436
474
idx_data_points = right_node_idx_data_points ,
475
+ linear_params = right_node_linear_params ,
437
476
)
438
477
tree .grow_tree (index_leaf_node , new_split_node , new_left_node , new_right_node )
439
478
440
479
return True , index_selected_predictor
441
480
442
481
443
def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X):
    """Partition a node's data-point indices according to a splitting rule.

    Parameters
    ----------
    split_value : float
        Threshold on the selected predictor; rows whose value is
        ``<= split_value`` go to the left child, the rest to the right.
    idx_data_points : np.ndarray
        Integer row indices of ``X`` currently assigned to the node being split.
    selected_predictor : int
        Column of ``X`` used as the splitting variable.
    X : np.ndarray
        Design matrix (observations x predictors).

    Returns
    -------
    tuple of np.ndarray
        ``(left_node_idx_data_points, right_node_idx_data_points)``.
    """
    left_idx = X[idx_data_points, selected_predictor] <= split_value
    left_node_idx_data_points = idx_data_points[left_idx]
    right_node_idx_data_points = idx_data_points[~left_idx]

    return left_node_idx_data_points, right_node_idx_data_points
452
489
453
490
454
def draw_leaf_value(Y_mu_pred, X_mu, mean, linear_fit, m, normal, mu_std, response):
    """Draw Gaussian distributed leaf values.

    Parameters
    ----------
    Y_mu_pred : np.ndarray
        Residual-like target values for the data points in this leaf.
    X_mu : np.ndarray
        Values of the splitting predictor for the same data points; only used
        when ``response == "linear"``.
    mean : callable
        Fast mean function (see ``fast_mean``).
    linear_fit : callable
        Fast simple-linear-regression function (see ``fast_linear_fit``).
    m : int
        Number of trees; leaf contributions are scaled by ``1 / m``.
    normal : NormalSampler
        Source of standard-normal draws.
    mu_std : float
        Standard deviation of the Gaussian perturbation added to the leaf mean.
    response : str
        Either ``"constant"`` or ``"linear"`` (callers resolve ``"mix"``
        before calling). Other values would leave ``mu_mean`` undefined.

    Returns
    -------
    tuple
        ``(draw, linear_params)`` where ``linear_params`` is ``None`` unless a
        linear response was fitted, in which case it is ``(a, b)``.
    """
    linear_params = None
    if Y_mu_pred.size == 0:
        # Empty leaf: nothing to inform the value, contribute zero.
        return 0, linear_params
    elif Y_mu_pred.size == 1:
        # Single point: its value (scaled by 1/m) is the leaf mean.
        mu_mean = Y_mu_pred.item() / m
    else:
        if response == "constant":
            mu_mean = mean(Y_mu_pred) / m
        elif response == "linear":
            Y_fit, linear_params = linear_fit(X_mu, Y_mu_pred)
            mu_mean = Y_fit / m
    draw = normal.random() * mu_std + mu_mean
    return draw, linear_params
462
506
463
507
464
508
def fast_mean ():
@@ -479,6 +523,29 @@ def mean(a):
479
523
return mean
480
524
481
525
526
def fast_linear_fit():
    """If available use Numba to speed up the computation of the linear fit.

    Returns
    -------
    callable
        ``linear_fit(X, Y)`` fitting ``Y ~ a + b * X`` by ordinary least
        squares, returning ``(Y_fit, (a, b))``. If Numba is installed the
        function is JIT-compiled, otherwise the pure-Python version is
        returned.

    Note
    ----
    NOTE(review): the closed-form slope divides by ``X @ X - n * xbar ** 2``,
    which is zero when all ``X`` values are identical — callers appear to rely
    on this never happening for a chosen split; confirm upstream.
    """

    def linear_fit(X, Y):

        n = len(Y)
        xbar = np.sum(X) / n
        ybar = np.sum(Y) / n

        # Closed-form OLS estimates for a single predictor.
        b = (X @ Y - n * xbar * ybar) / (X @ X - n * xbar ** 2)
        a = ybar - b * xbar

        Y_fit = a + b * X
        return Y_fit, (a, b)

    try:
        # Numba is an optional dependency: fall back to plain Python if absent.
        from numba import jit

        return jit(linear_fit)
    except ImportError:
        return linear_fit
547
+
548
+
482
549
def discrete_uniform_sampler (upper_value ):
483
550
"""Draw from the uniform distribution with bounds [0, upper_value).
484
551
0 commit comments