Skip to content

Commit 78439e6

Browse files
committed
Added tests for split rules
1 parent bdb081a commit 78439e6

File tree

3 files changed

+43
-4
lines changed

3 files changed

+43
-4
lines changed

pymc_bart/bart.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -143,7 +143,7 @@ def __new__(
143143
alpha=alpha,
144144
beta=beta,
145145
split_prior=split_prior,
146-
split_rules=split_rules
146+
split_rules=split_rules,
147147
),
148148
)()
149149

pymc_bart/split_rules.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -87,9 +87,7 @@ def get_split_value(available_splitting_values):
8787
if available_splitting_values.size > 1 and not np.all(
8888
available_splitting_values == available_splitting_values[0]
8989
):
90-
unique_values = np.unique(available_splitting_values)[
91-
:-1
92-
] # Remove last one so it always goes to left
90+
unique_values = np.unique(available_splitting_values)
9391
while True:
9492
sample = np.random.randint(0, 2, size=len(unique_values)).astype(bool)
9593
if np.any(sample):

tests/test_split_rules.py

Lines changed: 41 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,41 @@
1+
import numpy as np
2+
3+
from pymc_bart.split_rules import ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule
4+
import pytest
5+
6+
7+
@pytest.mark.parametrize(
    argnames="Rule",
    argvalues=[ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule],
    ids=["continuous", "one_hot", "subset"],
)
def test_split_rule(Rule):
    """Exercise the ``get_split_value`` / ``divide`` pair of a split rule.

    Runs once per parametrized rule class and checks that:

    * no split value is proposed when only one candidate value exists,
    * ``divide`` yields one boolean flag per candidate value,
    * ``divide`` gives the same answer when called twice with the same input,
    * over many random split values, nearly every candidate ends up on
      either side at least occasionally.
    """
    # A single candidate value leaves nothing to split on.
    assert Rule.get_split_value(np.zeros(1)) is None

    # A proposed split value must be something ``divide`` can consume.
    candidates = np.arange(10).astype(float)
    split_value = Rule.get_split_value(candidates)
    goes_left = Rule.divide(candidates, split_value)

    # One boolean per candidate: this is what makes the split binary.
    assert len(goes_left) == len(candidates)
    assert goes_left.dtype == "bool"

    # Same inputs, same partition: ``divide`` itself must not be random.
    assert (goes_left == Rule.divide(candidates, split_value)).all()

    # Estimate, per candidate, how often it is flagged across freshly drawn
    # split values. NB: not strictly required of a rule, but a good proxy
    # for "most elements can go either direction".
    draws = []
    for _ in range(10000):
        draws.append(Rule.divide(candidates, Rule.get_split_value(candidates)))
    left_frequency = np.array(draws).mean(axis=0)

    # At most one candidate may be pinned to a single side.
    min_mixed = len(candidates) - 1
    assert (left_frequency > 0.01).sum() >= min_mixed
    assert (left_frequency < 0.99).sum() >= min_mixed

0 commit comments

Comments
 (0)