Skip to content

Allow different splitting rules #96

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pymc_bart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from pymc_bart.bart import BART
from pymc_bart.pgbart import PGBART
from pymc_bart.split_rules import ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule
from pymc_bart.utils import (
plot_convergence,
plot_pdp,
Expand Down
7 changes: 7 additions & 0 deletions pymc_bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from .tree import Tree
from .utils import TensorLike, _sample_posterior
from .split_rules import SplitRule

__all__ = ["BART"]

Expand Down Expand Up @@ -91,6 +92,10 @@ class BART(Distribution):
Each element of split_prior should be in the [0, 1] interval and the elements should sum to
1. Otherwise they will be normalized.
Defaults to 0, i.e. all covariates have the same prior probability to be selected.
split_rules : Optional[List[SplitRule]], default None
List of SplitRule objects, one per column in input data.
Allows using different split rules for different columns. Default is ContinuousSplitRule.
Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.

Notes
-----
Expand All @@ -109,6 +114,7 @@ def __new__(
beta: float = 2.0,
response: str = "constant",
split_prior: Optional[List[float]] = None,
split_rules: Optional[SplitRule] = None,
**kwargs,
):
manager = Manager()
Expand All @@ -134,6 +140,7 @@ def __new__(
alpha=alpha,
beta=beta,
split_prior=split_prior,
split_rules=split_rules,
),
)()

Expand Down
33 changes: 15 additions & 18 deletions pymc_bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from pymc_bart.bart import BARTRV
from pymc_bart.tree import Node, Tree, get_idx_left_child, get_idx_right_child, get_depth
from pymc_bart.split_rules import ContinuousSplitRule


class ParticleTree:
Expand Down Expand Up @@ -141,6 +142,11 @@ def __init__(
else:
self.alpha_vec = np.ones(self.X.shape[1], dtype=np.int32)

if self.bart.split_rules:
self.split_rules = self.bart.split_rules
else:
self.split_rules = [ContinuousSplitRule] * self.X.shape[1]

init_mean = self.bart.Y.mean()
self.num_observations = self.X.shape[0]
self.num_variates = self.X.shape[1]
Expand All @@ -164,6 +170,7 @@ def __init__(
idx_data_points=np.arange(self.num_observations, dtype="int32"),
num_observations=self.num_observations,
shape=self.shape,
split_rules=self.split_rules,
)

self.normal = NormalSampler(1, self.shape)
Expand Down Expand Up @@ -443,13 +450,17 @@ def grow_tree(
idx_data_points, available_splitting_values = filter_missing_values(
X[idx_data_points, selected_predictor], idx_data_points, missing_data
)
split_value = get_split_value(available_splitting_values)

split_rule = tree.split_rules[selected_predictor]

split_value = split_rule.get_split_value(available_splitting_values)

if split_value is None:
return None
new_idx_data_points = get_new_idx_data_points(
available_splitting_values, split_value, idx_data_points
)

to_left = split_rule.divide(available_splitting_values, split_value)
new_idx_data_points = idx_data_points[to_left], idx_data_points[~to_left]

current_node_children = (
get_idx_left_child(index_leaf_node),
get_idx_right_child(index_leaf_node),
Expand Down Expand Up @@ -481,12 +492,6 @@ def grow_tree(
return current_node_children


@njit
def get_new_idx_data_points(available_splitting_values, split_value, idx_data_points):
split_idx = available_splitting_values <= split_value
return idx_data_points[split_idx], idx_data_points[~split_idx]


def filter_missing_values(available_splitting_values, idx_data_points, missing_data):
if missing_data:
mask = ~np.isnan(available_splitting_values)
Expand All @@ -495,14 +500,6 @@ def filter_missing_values(available_splitting_values, idx_data_points, missing_d
return idx_data_points, available_splitting_values


def get_split_value(available_splitting_values):
split_value = None
if available_splitting_values.size > 0:
idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values))
split_value = available_splitting_values[idx_selected_splitting_values]
return split_value


def draw_leaf_value(
y_mu_pred: npt.NDArray[np.float_],
x_mu: npt.NDArray[np.float_],
Expand Down
103 changes: 103 additions & 0 deletions pymc_bart/split_rules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2022 The PyMC Developers
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add tests for this module?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can try. Don't have much experience with pytest but I can hopefully figure it out :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests in https://github.com/pymc-devs/pymc-bart/blob/main/tests/test_tree.py are probably the simplest ones in PyMC-BART, you can try following a similar pattern.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod

import numpy as np
from numba import njit


class SplitRule(ABC):
    """
    Abstract template class for a split rule.

    Concrete subclasses implement the two halves of a split decision:
    ``get_split_value`` proposes a pivot from the candidate values, and
    ``divide`` assigns each observation to the left child (True) or the
    right child (False) given that pivot.

    Inheriting from ``ABC`` makes the ``@abstractmethod`` contract
    enforceable: the template cannot be instantiated and subclasses that
    forget to implement a method fail at instantiation time.
    """

    @staticmethod
    @abstractmethod
    def get_split_value(available_splitting_values):
        """Return a split value drawn from ``available_splitting_values``, or None if no valid split exists."""

    @staticmethod
    @abstractmethod
    def divide(available_splitting_values, split_value):
        """Return a boolean mask over ``available_splitting_values``: True routes to the left child."""


class ContinuousSplitRule(SplitRule):
    """
    Standard continuous split rule.

    A pivot is drawn uniformly at random from the observed values, and an
    observation goes to the left child when it is less than or equal to
    the pivot.
    """

    @staticmethod
    def get_split_value(available_splitting_values):
        """Pick a pivot uniformly at random; None when fewer than two candidates exist."""
        if available_splitting_values.size <= 1:
            return None
        # Uniform index into the candidate array (same draw as np.random.random scaled by length).
        selected_idx = int(np.random.random() * len(available_splitting_values))
        return available_splitting_values[selected_idx]

    @staticmethod
    @njit
    def divide(available_splitting_values, split_value):
        """Boolean mask: True (left child) where value <= pivot."""
        return available_splitting_values <= split_value


class OneHotSplitRule(SplitRule):
    """One-vs-rest categorical rule: branch on whether the variable equals one chosen category."""

    @staticmethod
    def get_split_value(available_splitting_values):
        """Pick one observed category at random.

        Returns None when no informative split exists: fewer than two
        observations, or every observation carries the same category.
        """
        if available_splitting_values.size <= 1:
            return None
        if np.all(available_splitting_values == available_splitting_values[0]):
            return None
        chosen = int(np.random.random() * len(available_splitting_values))
        return available_splitting_values[chosen]

    @staticmethod
    @njit
    def divide(available_splitting_values, split_value):
        """Boolean mask: True (left child) where the value equals the chosen category."""
        return available_splitting_values == split_value


class SubsetSplitRule(SplitRule):
    """
    Categorical subset rule: branch on membership in a random subset of the
    observed categories.

    This is the approach taken by Sameer K. Deshpande.
    flexBART: Flexible Bayesian regression trees with categorical predictors. arXiv,
    `link <https://arxiv.org/abs/2211.04459>`__
    """

    @staticmethod
    def get_split_value(available_splitting_values):
        """Draw a random non-empty subset of the observed categories.

        Returns None when no informative split exists: fewer than two
        observations, or a single category present.
        """
        if available_splitting_values.size <= 1:
            return None
        if np.all(available_splitting_values == available_splitting_values[0]):
            return None
        unique_values = np.unique(available_splitting_values)
        # Rejection-sample a Bernoulli(0.5) mask until it selects at least one category.
        # NOTE(review): the mask may also select *every* category, producing an
        # empty right child — confirm whether that degenerate split is intended.
        sample = np.random.randint(0, 2, size=len(unique_values)).astype(bool)
        while not np.any(sample):
            sample = np.random.randint(0, 2, size=len(unique_values)).astype(bool)
        return unique_values[sample]

    @staticmethod
    def divide(available_splitting_values, split_value):
        """Boolean mask: True (left child) where the value belongs to the chosen subset."""
        return np.isin(available_splitting_values, split_value)
Loading