Skip to content

Commit 9f702e0

Browse files
committed
Added tests for split rules
1 parent bdb081a commit 9f702e0

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

tests/test_split_rules.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import numpy as np
2+
3+
from pymc_bart.split_rules import ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule
4+
import pytest
5+
6+
@pytest.mark.parametrize(
7+
argnames="Rule",
8+
argvalues=[ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule],
9+
ids=["continuous", "one_hot", "subset"],
10+
)
11+
def test_split_rule(Rule):
12+
13+
# Should return None if only one available value to pick from
14+
assert Rule.get_split_value(np.zeros(1)) is None
15+
16+
# get_split should return a value divide can use
17+
available_values = np.arange(10).astype(float)
18+
sv = Rule.get_split_value(available_values)
19+
left = Rule.divide(available_values,sv)
20+
21+
# divide should return a boolean numpy array
22+
# This de facto ensures it is a binary split
23+
assert len(left) == len(available_values)
24+
assert left.dtype == 'bool'

0 commit comments

Comments
 (0)