Skip to content

Commit 825d1d9

Browse files
committed
Made existing tests work and added a new test for separate_trees
1 parent 7c04127 commit 825d1d9

File tree

3 files changed

+27
-4
lines changed

3 files changed

+27
-4
lines changed

pymc_bart/pgbart.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def __init__(
175175
leaf_node_value=init_mean / self.m,
176176
idx_data_points=np.arange(self.num_observations, dtype="int32"),
177177
num_observations=self.num_observations,
178-
shape=self.leaves_shape
178+
shape=self.leaves_shape,
179179
)
180180

181181
self.normal = NormalSampler(1, self.leaves_shape)
@@ -357,7 +357,8 @@ def update_weight(self, particle: ParticleTree, odim: int) -> None:
357357
"""
358358

359359
delta = (
360-
np.identity(self.trees_shape)[odim][:, None, None] * particle.tree._predict()[None, :, :]
360+
np.identity(self.trees_shape)[odim][:, None, None]
361+
* particle.tree._predict()[None, :, :]
361362
)
362363

363364
new_likelihood = self.likelihood_logp((self.sum_trees_noi + delta).flatten())

pymc_bart/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _sample_posterior(
4949
X = X.eval()
5050

5151
if size is None:
52-
size_iter: Union[List, Tuple] = ()
52+
size_iter: Union[List, Tuple] = (1,)
5353
elif isinstance(size, int):
5454
size_iter = [size]
5555
else:
@@ -69,7 +69,9 @@ def _sample_posterior(
6969
for ind, p in enumerate(pred):
7070
for oi, odim_trees in enumerate(stacked_trees[idx[ind]]):
7171
for tree in odim_trees:
72-
p[oi] += np.vstack([tree.predict(x=x, excluded=excluded, shape=tree_shape) for x in X]).T
72+
p[oi] += np.vstack(
73+
[tree.predict(x=x, excluded=excluded, shape=tree_shape) for x in X]
74+
).T
7375

7476
# pred.reshape((*size_iter, shape, -1))
7577
return pred.transpose((0, 3, 1, 2)).reshape((*size_iter, -1, shape))

tests/test_bart.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,23 @@ def test_bart_moment(size, expected):
217217
with pm.Model() as model:
218218
pmb.BART("x", X=X, Y=Y, size=size)
219219
assert_moment_is_expected(model, expected)
220+
221+
222+
@pytest.mark.parametrize(
223+
argnames="separate_trees",
224+
argvalues=[False, True],
225+
ids=["shared-trees", "separate-trees"],
226+
)
227+
def test_categorical(separate_trees):
228+
229+
Y = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
230+
X = np.concatenate([Y[:, None], np.random.randint(0, 6, size=(9, 4))], axis=1)
231+
232+
with pm.Model() as model:
233+
lo = pmb.BART("logodds", X, Y, m=2, shape=(3, 9), separate_trees=separate_trees)
234+
y = pm.Categorical("y", p=pm.math.softmax(lo.T, axis=-1), observed=Y)
235+
idata = pm.sample(random_seed=3415, tune=300, draws=300)
236+
idata = pm.sample_posterior_predictive(idata, predictions=True, extend_inferencedata=True)
237+
238+
# Fit should be good enough so right category is selected over 50% of time
239+
assert (idata.predictions.y.median(["chain", "draw"]) == Y).all()

0 commit comments

Comments
 (0)