Skip to content

Commit 488f9ca

Browse files
committed
Rename variables to comply with black and improve the test to include split_rules
1 parent 5cc033c commit 488f9ca

File tree

2 files changed

+16
-9
lines changed

2 files changed

+16
-9
lines changed

pymc_bart/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,15 @@ def _sample_posterior(
6161

6262
idx = rng.integers(0, len(stacked_trees), size=flatten_size)
6363

64-
st_shape = len(stacked_trees[0])
65-
tree_shape = shape // st_shape
64+
trees_shape = len(stacked_trees[0])
65+
leaves_shape = shape // trees_shape
6666

67-
pred = np.zeros((flatten_size, st_shape, tree_shape, X.shape[0]))
67+
pred = np.zeros((flatten_size, trees_shape, leaves_shape, X.shape[0]))
6868

6969
for ind, p in enumerate(pred):
7070
for odim, odim_trees in enumerate(stacked_trees[idx[ind]]):
7171
for tree in odim_trees:
72-
p[odim] += tree.predict(x=X, excluded=excluded, shape=tree_shape)
72+
p[odim] += tree.predict(x=X, excluded=excluded, shape=leaves_shape)
7373

7474
# pred.reshape((*size_iter, shape, -1))
7575
return pred.transpose((0, 3, 1, 2)).reshape((*size_iter, -1, shape))

tests/test_bart.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -220,17 +220,24 @@ def test_bart_moment(size, expected):
220220

221221

222222
@pytest.mark.parametrize(
223-
argnames="separate_trees",
224-
argvalues=[False, True],
225-
ids=["shared-trees", "separate-trees"],
223+
argnames="separate_trees,split_rule",
224+
argvalues=[
225+
(False,pmb.ContinuousSplitRule),
226+
(False,pmb.OneHotSplitRule),
227+
(False,pmb.SubsetSplitRule),
228+
(True,pmb.ContinuousSplitRule)
229+
],
230+
ids=["continuous", "one-hot", "subset", "separate-trees"],
226231
)
227-
def test_categorical(separate_trees):
232+
def test_categorical_model(separate_trees,split_rule):
228233

229234
Y = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
230235
X = np.concatenate([Y[:, None], np.random.randint(0, 6, size=(9, 4))], axis=1)
231236

232237
with pm.Model() as model:
233-
lo = pmb.BART("logodds", X, Y, m=2, shape=(3, 9), separate_trees=separate_trees)
238+
lo = pmb.BART("logodds", X, Y, m=2, shape=(3, 9),
239+
split_rules=[split_rule]*5,
240+
separate_trees=separate_trees)
234241
y = pm.Categorical("y", p=pm.math.softmax(lo.T, axis=-1), observed=Y)
235242
idata = pm.sample(random_seed=3415, tune=300, draws=300)
236243
idata = pm.sample_posterior_predictive(idata, predictions=True, extend_inferencedata=True)

0 commit comments

Comments
 (0)