@@ -217,3 +217,29 @@ def test_bart_moment(size, expected):
217
217
with pm .Model () as model :
218
218
pmb .BART ("x" , X = X , Y = Y , size = size )
219
219
assert_moment_is_expected (model , expected )
220
+
221
+
222
+ @pytest .mark .parametrize (
223
+ argnames = "separate_trees" ,
224
+ argvalues = [False ,True ],
225
+ ids = ["shared-trees" , "separate-trees" ],
226
+ )
227
+ def test_categorical (separate_trees ):
228
+
229
+ Y = np .array ([
230
+ [1 ,1 ,1 ,0 ,0 ,0 ,0 ,0 ,0 ],
231
+ [0 ,0 ,0 ,1 ,1 ,1 ,0 ,0 ,0 ],
232
+ [0 ,0 ,0 ,0 ,0 ,0 ,1 ,1 ,1 ],
233
+ ]).T
234
+
235
+ X = np .concatenate ([Y [:,:2 ], np .random .randint (0 ,6 ,size = (9 ,4 ))],axis = 1 )
236
+
237
+ with pm .Model () as model :
238
+ w = pmb .BART ("w" , X , Y , m = 3 , shape = Y .T .shape , separate_trees = separate_trees )
239
+ y = pm .Multinomial ("y" , p = pm .math .softmax (w .T ,axis = - 1 ), n = 1 , observed = Y , shape = Y .shape )
240
+ idata = pm .sample (random_seed = 3415 , tune = 300 , draws = 300 )
241
+ idata = pm .sample_posterior_predictive (idata , predictions = True , extend_inferencedata = True )
242
+
243
+ probs = np .array (idata .predictions .y .mean (['chain' ,'draw' ]))[Y .astype (bool )]
244
+
245
+ assert (probs > 0.4 ).all ()
0 commit comments