Skip to content

Commit 73f5129

Browse files
committed
Made existing tests work and added a new test for separate_trees
1 parent 7c04127 commit 73f5129

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

pymc_bart/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _sample_posterior(
4949
X = X.eval()
5050

5151
if size is None:
52-
size_iter: Union[List, Tuple] = ()
52+
size_iter: Union[List, Tuple] = (1,)
5353
elif isinstance(size, int):
5454
size_iter = [size]
5555
else:

tests/test_bart.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,23 @@ def test_bart_moment(size, expected):
217217
with pm.Model() as model:
218218
pmb.BART("x", X=X, Y=Y, size=size)
219219
assert_moment_is_expected(model, expected)
220+
221+
222+
@pytest.mark.parametrize(
223+
argnames="separate_trees",
224+
argvalues=[False, True],
225+
ids=["shared-trees", "separate-trees"],
226+
)
227+
def test_categorical(separate_trees):
228+
229+
Y = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
230+
X = np.concatenate([Y[:, None], np.random.randint(0, 6, size=(9, 4))], axis=1)
231+
232+
with pm.Model() as model:
233+
lo = pmb.BART("logodds", X, Y, m=2, shape=(3, 9), separate_trees=separate_trees)
234+
y = pm.Categorical("y", p=pm.math.softmax(lo.T, axis=-1), observed=Y)
235+
idata = pm.sample(random_seed=3415, tune=300, draws=300)
236+
idata = pm.sample_posterior_predictive(idata, predictions=True, extend_inferencedata=True)
237+
238+
# Fit should be good enough so right category is selected over 50% of time
239+
assert (idata.predictions.y.median(["chain", "draw"]) == Y).all()

0 commit comments

Comments
 (0)