
Commit e5c95f0

velochy and aloctavodia authored
Allow different splitting rules (#96)
* Added an option to specify different splitting rules for each dimension
* Allow batch sampling in tree.predict, speeding up _sample_posterior by 60x
* Addressed most issues flagged in PR comments
* Fix the reference to flexBART
* Added tests for split rules
* Fix mypy annotations
* Fix linter issues

Co-authored-by: Osvaldo A Martin <aloctavodia@gmail.com>
1 parent 160d0ae commit e5c95f0

File tree

7 files changed: 222 additions (+) and 49 deletions (−)

pymc_bart/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 
 from pymc_bart.bart import BART
 from pymc_bart.pgbart import PGBART
+from pymc_bart.split_rules import ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule
 from pymc_bart.utils import (
     plot_convergence,
     plot_pdp,

pymc_bart/bart.py

Lines changed: 7 additions & 0 deletions
@@ -27,6 +27,7 @@
 
 from .tree import Tree
 from .utils import TensorLike, _sample_posterior
+from .split_rules import SplitRule
 
 __all__ = ["BART"]
 
@@ -91,6 +92,10 @@ class BART(Distribution):
         Each element of split_prior should be in the [0, 1] interval and the elements should sum to
         1. Otherwise they will be normalized.
         Defaults to 0, i.e. all covariates have the same prior probability to be selected.
+    split_rules : Optional[SplitRule], default None
+        List of SplitRule objects, one per column in input data.
+        Allows using different split rules for different columns. Default is ContinuousSplitRule.
+        Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
 
     Notes
     -----
@@ -109,6 +114,7 @@ def __new__(
         beta: float = 2.0,
         response: str = "constant",
         split_prior: Optional[List[float]] = None,
+        split_rules: Optional[SplitRule] = None,
         **kwargs,
     ):
         manager = Manager()
@@ -134,6 +140,7 @@ def __new__(
                 alpha=alpha,
                 beta=beta,
                 split_prior=split_prior,
+                split_rules=split_rules,
             ),
         )()
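The docstring entry above is the whole user-facing change: pass one SplitRule per column of the design matrix. A minimal usage sketch follows; the synthetic data, variable names, and hyperparameters are illustrative, not taken from this commit.

import numpy as np
import pymc as pm
import pymc_bart as pmb

rng = np.random.default_rng(0)
# Column 0 is continuous, column 1 encodes a categorical variable as integer codes.
X = np.column_stack([rng.normal(size=100), rng.integers(0, 4, size=100).astype(float)])
Y = X[:, 0] + (X[:, 1] == 2.0) + rng.normal(0, 0.1, size=100)

with pm.Model() as model:
    # One rule per column: a continuous pivot for column 0, one-hot branching for column 1.
    mu = pmb.BART(
        "mu", X, Y, m=20,
        split_rules=[pmb.ContinuousSplitRule, pmb.OneHotSplitRule],
    )
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=Y)
    idata = pm.sample()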

pymc_bart/pgbart.py

Lines changed: 15 additions & 18 deletions
@@ -26,6 +26,7 @@
 
 from pymc_bart.bart import BARTRV
 from pymc_bart.tree import Node, Tree, get_idx_left_child, get_idx_right_child, get_depth
+from pymc_bart.split_rules import ContinuousSplitRule
 
 
 class ParticleTree:
@@ -141,6 +142,11 @@ def __init__(
         else:
             self.alpha_vec = np.ones(self.X.shape[1], dtype=np.int32)
 
+        if self.bart.split_rules:
+            self.split_rules = self.bart.split_rules
+        else:
+            self.split_rules = [ContinuousSplitRule] * self.X.shape[1]
+
         init_mean = self.bart.Y.mean()
         self.num_observations = self.X.shape[0]
         self.num_variates = self.X.shape[1]
@@ -164,6 +170,7 @@ def __init__(
             idx_data_points=np.arange(self.num_observations, dtype="int32"),
             num_observations=self.num_observations,
             shape=self.shape,
+            split_rules=self.split_rules,
         )
 
         self.normal = NormalSampler(1, self.shape)
@@ -443,13 +450,17 @@ def grow_tree(
     idx_data_points, available_splitting_values = filter_missing_values(
         X[idx_data_points, selected_predictor], idx_data_points, missing_data
     )
-    split_value = get_split_value(available_splitting_values)
+
+    split_rule = tree.split_rules[selected_predictor]
+
+    split_value = split_rule.get_split_value(available_splitting_values)
 
     if split_value is None:
         return None
-    new_idx_data_points = get_new_idx_data_points(
-        available_splitting_values, split_value, idx_data_points
-    )
+
+    to_left = split_rule.divide(available_splitting_values, split_value)
+    new_idx_data_points = idx_data_points[to_left], idx_data_points[~to_left]
+
     current_node_children = (
         get_idx_left_child(index_leaf_node),
         get_idx_right_child(index_leaf_node),
@@ -481,12 +492,6 @@ def grow_tree(
     return current_node_children
 
 
-@njit
-def get_new_idx_data_points(available_splitting_values, split_value, idx_data_points):
-    split_idx = available_splitting_values <= split_value
-    return idx_data_points[split_idx], idx_data_points[~split_idx]
-
-
 def filter_missing_values(available_splitting_values, idx_data_points, missing_data):
     if missing_data:
         mask = ~np.isnan(available_splitting_values)
@@ -495,14 +500,6 @@ def filter_missing_values(available_splitting_values, idx_data_points, missing_data
     return idx_data_points, available_splitting_values
 
 
-def get_split_value(available_splitting_values):
-    split_value = None
-    if available_splitting_values.size > 0:
-        idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values))
-        split_value = available_splitting_values[idx_selected_splitting_values]
-    return split_value
-
-
 def draw_leaf_value(
     y_mu_pred: npt.NDArray[np.float_],
     x_mu: npt.NDArray[np.float_],
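For reference, here is a standalone sketch of the dispatch that grow_tree now performs: look up the selected column's rule, draw a split value with get_split_value, and partition the data-point indices with divide. The toy matrix and the hard-coded predictor index are hypothetical, and the real sampler additionally handles missing data, leaf values, and tree bookkeeping.

import numpy as np
from pymc_bart.split_rules import ContinuousSplitRule, OneHotSplitRule

X = np.array([[0.1, 2.0], [0.7, 0.0], [0.3, 2.0], [0.9, 1.0]])
split_rules = [ContinuousSplitRule, OneHotSplitRule]  # one rule per column
idx_data_points = np.arange(X.shape[0], dtype="int32")

selected_predictor = 1  # the sampler draws this at random; fixed here for illustration
available_splitting_values = X[idx_data_points, selected_predictor]

split_rule = split_rules[selected_predictor]
split_value = split_rule.get_split_value(available_splitting_values)

if split_value is not None:
    # Boolean mask: which observations go to the left child under this rule.
    to_left = split_rule.divide(available_splitting_values, split_value)
    left_idx, right_idx = idx_data_points[to_left], idx_data_points[~to_left]
    print(split_value, left_idx, right_idx)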

pymc_bart/split_rules.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+# Copyright 2022 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import abstractmethod
+from numba import njit
+import numpy as np
+
+
+class SplitRule:
+    """
+    Abstract template class for a split rule
+    """
+
+    @staticmethod
+    @abstractmethod
+    def get_split_value(available_splitting_values):
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def divide(available_splitting_values, split_value):
+        pass
+
+
+class ContinuousSplitRule(SplitRule):
+    """
+    Standard continuous split rule: pick a pivot value and split
+    depending on if variable is smaller or greater than the value picked.
+    """
+
+    @staticmethod
+    def get_split_value(available_splitting_values):
+        split_value = None
+        if available_splitting_values.size > 1:
+            idx_selected_splitting_values = int(
+                np.random.random() * len(available_splitting_values)
+            )
+            split_value = available_splitting_values[idx_selected_splitting_values]
+        return split_value
+
+    @staticmethod
+    @njit
+    def divide(available_splitting_values, split_value):
+        return available_splitting_values <= split_value
+
+
+class OneHotSplitRule(SplitRule):
+    """Choose a single categorical value and branch on if the variable is that value or not"""
+
+    @staticmethod
+    def get_split_value(available_splitting_values):
+        split_value = None
+        if available_splitting_values.size > 1 and not np.all(
+            available_splitting_values == available_splitting_values[0]
+        ):
+            idx_selected_splitting_values = int(
+                np.random.random() * len(available_splitting_values)
+            )
+            split_value = available_splitting_values[idx_selected_splitting_values]
+        return split_value
+
+    @staticmethod
+    @njit
+    def divide(available_splitting_values, split_value):
+        return available_splitting_values == split_value
+
+
+class SubsetSplitRule(SplitRule):
+    """
+    Choose a random subset of the categorical values and branch on belonging to that set.
+    This is the approach taken by Sameer K. Deshpande.
+    flexBART: Flexible Bayesian regression trees with categorical predictors. arXiv,
+    `link <https://arxiv.org/abs/2211.04459>`__
+    """
+
+    @staticmethod
+    def get_split_value(available_splitting_values):
+        split_value = None
+        if available_splitting_values.size > 1 and not np.all(
+            available_splitting_values == available_splitting_values[0]
+        ):
+            unique_values = np.unique(available_splitting_values)
+            while True:
+                sample = np.random.randint(0, 2, size=len(unique_values)).astype(bool)
+                if np.any(sample):
+                    break
+            split_value = unique_values[sample]
+        return split_value
+
+    @staticmethod
+    def divide(available_splitting_values, split_value):
+        return np.isin(available_splitting_values, split_value)
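The three rules differ only in how a split value is drawn and in how observations are sent to the left child: less-or-equal for ContinuousSplitRule, equality for OneHotSplitRule, and set membership for SubsetSplitRule. A toy comparison, with made-up arrays and print statements purely for illustration:

import numpy as np
from pymc_bart.split_rules import ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule

continuous = np.array([0.3, 1.2, 0.7, 2.5])
categories = np.array([0.0, 1.0, 2.0, 1.0, 3.0])  # categorical levels stored as floats

pivot = ContinuousSplitRule.get_split_value(continuous)
print(pivot, ContinuousSplitRule.divide(continuous, pivot))    # mask: value <= pivot

level = OneHotSplitRule.get_split_value(categories)
print(level, OneHotSplitRule.divide(categories, level))        # mask: value == level

subset = SubsetSplitRule.get_split_value(categories)
print(subset, SubsetSplitRule.divide(categories, subset))      # mask: value in subset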
