@@ -21,6 +21,7 @@
 def _sample_posterior(
     all_trees: List[List[Tree]],
     X: TensorLike,
+    m: int,
     rng: np.random.Generator,
     size: Optional[Union[int, Tuple[int, ...]]] = None,
     excluded: Optional[npt.NDArray[np.int_]] = None,
@@ -35,6 +36,8 @@ def _sample_posterior(
     X : tensor-like
         A covariate matrix. Use the same used to fit BART for in-sample predictions or a new one for
         out-of-sample predictions.
+    m : int
+        Number of trees
     rng : NumPy RandomGenerator
     size : int or tuple
         Number of samples.
@@ -57,7 +60,7 @@ def _sample_posterior(
         flatten_size *= s
 
     idx = rng.integers(0, len(stacked_trees), size=flatten_size)
-    shape = stacked_trees[0][0].predict(X[0]).size
+    shape = stacked_trees[0][0].predict(x=X[0], m=m).size
 
     pred = np.zeros((flatten_size, X.shape[0], shape))
 
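The new keyword is threaded from the public helpers down to Tree.predict, so every caller of _sample_posterior now has to supply the tree count. A minimal usage sketch, assuming a pymc-bart style model; the variable name mu, the toy data, and m=20 are illustrative stand-ins, not part of this commit:

    import numpy as np
    import pymc as pm
    import pymc_bart as pmb

    rng = np.random.default_rng(0)
    X = rng.normal(size=(50, 3))
    y = X[:, 0] + rng.normal(scale=0.1, size=50)

    with pm.Model():
        mu = pmb.BART("mu", X, y, m=20)  # m = number of trees
        pm.Normal("y_obs", mu=mu, sigma=0.1, observed=y)
        pm.sample()

    # The Op behind the BART variable stores the fitted trees and m,
    # so both can be forwarded after this change.
    draws = _sample_posterior(
        mu.owner.op.all_trees, X=X, m=mu.owner.op.m, rng=rng, size=100
    )
    # draws.shape == (100, 50, shape): (samples, rows of X, per-row output size)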
@@ -220,6 +223,8 @@ def plot_dependence(
     -------
     axes: matplotlib axes
     """
+    m: int = bartrv.owner.op.m
+
     if kind not in ["pdp", "ice"]:
         raise ValueError(f"kind={kind} is not suported. Available option are 'pdp' or 'ice'")
 
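Note that plot_dependence keeps its public signature unchanged: rather than asking users for the tree count, it reads it from the Op that generated the BART variable, the same lookup sketched here for a fitted variable mu (name illustrative):

    m: int = mu.owner.op.m  # same attribute the hunk reads from bartrv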
@@ -294,15 +299,15 @@ def plot_dependence(
             new_X[:, indices_mi] = X[:, indices_mi]
             new_X[:, i] = x_i
             y_pred.append(
-                np.mean(_sample_posterior(all_trees, X=new_X, rng=rng, size=samples), 1)
+                np.mean(_sample_posterior(all_trees, X=new_X, m=m, rng=rng, size=samples), 1)
             )
             new_x_target.append(new_x_i)
         else:
             for instance in instances_ary:
                 new_X = X[idx_s]
                 new_X[:, indices_mi] = X[:, indices_mi][instance]
                 y_pred.append(
-                    np.mean(_sample_posterior(all_trees, X=new_X, rng=rng, size=samples), 0)
+                    np.mean(_sample_posterior(all_trees, X=new_X, m=m, rng=rng, size=samples), 0)
                 )
                 new_x_target.append(new_X[:, i])
         y_mins.append(np.min(y_pred))
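Both branches only gain the m=m keyword; the differing mean axes are preexisting and follow from _sample_posterior returning an array shaped (flatten_size, X.shape[0], shape), i.e. (samples, observations, output size), per the pred allocation above. The pdp branch averages over observations (axis 1) to get one partial-dependence value per posterior draw, while the ice branch averages over draws (axis 0) to get one value per instance. A toy illustration of the convention, with made-up sizes:

    import numpy as np

    draws = np.zeros((100, 75, 1))  # (posterior samples, observations, output size)
    np.mean(draws, 1).shape         # (100, 1): pdp, one value per posterior draw
    np.mean(draws, 0).shape         # (75, 1):  ice, one value per observation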
@@ -445,6 +450,8 @@ def plot_variable_importance(
     """
     _, axes = plt.subplots(2, 1, figsize=figsize)
 
+    m: int = bartrv.owner.op.m
+
     if hasattr(X, "columns") and hasattr(X, "values"):
         labels = X.columns
         X = X.values
@@ -474,13 +481,13 @@ def plot_variable_importance(
 
     all_trees = bartrv.owner.op.all_trees
 
-    predicted_all = _sample_posterior(all_trees, X=X, rng=rng, size=samples, excluded=None)
+    predicted_all = _sample_posterior(all_trees, X=X, m=m, rng=rng, size=samples, excluded=None)
 
     ev_mean = np.zeros(len(var_imp))
     ev_hdi = np.zeros((len(var_imp), 2))
     for idx, subset in enumerate(subsets):
         predicted_subset = _sample_posterior(
-            all_trees=all_trees, X=X, rng=rng, size=samples, excluded=subset
+            all_trees=all_trees, X=X, m=m, rng=rng, size=samples, excluded=subset
         )
         pearson = np.zeros(samples)
         for j in range(samples):
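The diff is truncated just before the body of this loop. What is visible shows plot_variable_importance sampling predictions once from the full ensemble and once per excluded subset, then allocating a per-draw pearson buffer, which suggests a draw-by-draw Pearson correlation between the two prediction sets. A sketch of that reading; the pearsonr call is an assumption, since the actual comparison line is cut off:

    import numpy as np
    from scipy.stats import pearsonr

    rng = np.random.default_rng(0)
    samples = 4
    predicted_all = rng.normal(size=(samples, 75, 1))
    predicted_subset = predicted_all + rng.normal(scale=0.1, size=(samples, 75, 1))

    pearson = np.zeros(samples)
    for j in range(samples):
        # Assumed comparison for draw j: how closely the reduced ensemble
        # tracks the full one; high correlation means the excluded
        # variables contribute little.
        pearson[j] = pearsonr(
            predicted_all[j].flatten(), predicted_subset[j].flatten()
        )[0]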