@@ -217,3 +217,23 @@ def test_bart_moment(size, expected):
217
217
with pm .Model () as model :
218
218
pmb .BART ("x" , X = X , Y = Y , size = size )
219
219
assert_moment_is_expected (model , expected )
220
+
221
+
222
+ @pytest .mark .parametrize (
223
+ argnames = "separate_trees" ,
224
+ argvalues = [False , True ],
225
+ ids = ["shared-trees" , "separate-trees" ],
226
+ )
227
+ def test_categorical (separate_trees ):
228
+
229
+ Y = np .array ([0 , 0 , 0 , 1 , 1 , 1 , 2 , 2 , 2 ])
230
+ X = np .concatenate ([Y [:, None ], np .random .randint (0 , 6 , size = (9 , 4 ))], axis = 1 )
231
+
232
+ with pm .Model () as model :
233
+ lo = pmb .BART ("logodds" , X , Y , m = 2 , shape = (3 , 9 ), separate_trees = separate_trees )
234
+ y = pm .Categorical ("y" , p = pm .math .softmax (lo .T , axis = - 1 ), observed = Y )
235
+ idata = pm .sample (random_seed = 3415 , tune = 300 , draws = 300 )
236
+ idata = pm .sample_posterior_predictive (idata , predictions = True , extend_inferencedata = True )
237
+
238
+ # Fit should be good enough so right category is selected over 50% of time
239
+ assert (idata .predictions .y .median (["chain" , "draw" ]) == Y ).all ()
0 commit comments