@@ -220,17 +220,24 @@ def test_bart_moment(size, expected):
220
220
221
221
222
222
@pytest .mark .parametrize (
223
- argnames = "separate_trees" ,
224
- argvalues = [False , True ],
225
- ids = ["shared-trees" , "separate-trees" ],
223
+ argnames = "separate_trees,split_rule" ,
224
+ argvalues = [
225
+ (False ,pmb .ContinuousSplitRule ),
226
+ (False ,pmb .OneHotSplitRule ),
227
+ (False ,pmb .SubsetSplitRule ),
228
+ (True ,pmb .ContinuousSplitRule )
229
+ ],
230
+ ids = ["continuous" , "one-hot" , "subset" , "separate-trees" ],
226
231
)
227
- def test_categorical (separate_trees ):
232
+ def test_categorical_model (separate_trees , split_rule ):
228
233
229
234
Y = np .array ([0 , 0 , 0 , 1 , 1 , 1 , 2 , 2 , 2 ])
230
235
X = np .concatenate ([Y [:, None ], np .random .randint (0 , 6 , size = (9 , 4 ))], axis = 1 )
231
236
232
237
with pm .Model () as model :
233
- lo = pmb .BART ("logodds" , X , Y , m = 2 , shape = (3 , 9 ), separate_trees = separate_trees )
238
+ lo = pmb .BART ("logodds" , X , Y , m = 2 , shape = (3 , 9 ),
239
+ split_rules = [split_rule ]* 5 ,
240
+ separate_trees = separate_trees )
234
241
y = pm .Categorical ("y" , p = pm .math .softmax (lo .T , axis = - 1 ), observed = Y )
235
242
idata = pm .sample (random_seed = 3415 , tune = 300 , draws = 300 )
236
243
idata = pm .sample_posterior_predictive (idata , predictions = True , extend_inferencedata = True )
0 commit comments