Skip to content

Commit 53ae2fd

Browse files
committed
Made existing tests work and added a new test for separate_trees
1 parent 7c04127 commit 53ae2fd

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

pymc_bart/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _sample_posterior(
4949
X = X.eval()
5050

5151
if size is None:
52-
size_iter: Union[List, Tuple] = ()
52+
size_iter: Union[List, Tuple] = (1,)
5353
elif isinstance(size, int):
5454
size_iter = [size]
5555
else:

tests/test_bart.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,29 @@ def test_bart_moment(size, expected):
217217
with pm.Model() as model:
218218
pmb.BART("x", X=X, Y=Y, size=size)
219219
assert_moment_is_expected(model, expected)
220+
221+
222+
@pytest.mark.parametrize(
223+
argnames="separate_trees",
224+
argvalues=[False,True],
225+
ids=["shared-trees", "separate-trees"],
226+
)
227+
def test_categorical(separate_trees):
228+
229+
Y = np.array([
230+
[1,1,1,0,0,0,0,0,0],
231+
[0,0,0,1,1,1,0,0,0],
232+
[0,0,0,0,0,0,1,1,1],
233+
]).T
234+
235+
X = np.concatenate([Y[:,:2], np.random.randint(0,6,size=(9,4))],axis=1)
236+
237+
with pm.Model() as model:
238+
w = pmb.BART("w", X, Y, m=3, shape=Y.T.shape, separate_trees=separate_trees)
239+
y = pm.Multinomial("y", p=pm.math.softmax(w.T,axis=-1), n=1, observed=Y, shape=Y.shape)
240+
idata = pm.sample(random_seed=3415, tune=300, draws=300)
241+
idata = pm.sample_posterior_predictive(idata, predictions=True, extend_inferencedata=True)
242+
243+
probs = np.array(idata.predictions.y.mean(['chain','draw']))[Y.astype(bool)]
244+
245+
assert (probs>0.4).all()

0 commit comments

Comments
 (0)