diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index 5b5e624..731e19e 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -15,6 +15,7 @@ from pymc_bart.bart import BART from pymc_bart.pgbart import PGBART +from pymc_bart.split_rules import ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule from pymc_bart.utils import ( plot_convergence, plot_pdp, diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index ac530d6..2b3a85e 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -27,6 +27,7 @@ from .tree import Tree from .utils import TensorLike, _sample_posterior +from .split_rules import SplitRule __all__ = ["BART"] @@ -91,6 +92,10 @@ class BART(Distribution): Each element of split_prior should be in the [0, 1] interval and the elements should sum to 1. Otherwise they will be normalized. Defaults to 0, i.e. all covariates have the same prior probability to be selected. + split_rules : Optional[List[SplitRule]], default None + List of SplitRule objects, one per column in input data. + Allows using different split rules for different columns. If None, ContinuousSplitRule is used for every column. + Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables. 
Notes ----- @@ -109,6 +114,7 @@ def __new__( beta: float = 2.0, response: str = "constant", split_prior: Optional[List[float]] = None, + split_rules: Optional[SplitRule] = None, **kwargs, ): manager = Manager() @@ -134,6 +140,7 @@ def __new__( alpha=alpha, beta=beta, split_prior=split_prior, + split_rules=split_rules, ), )() diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 8d07d5f..b49900e 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -26,6 +26,7 @@ from pymc_bart.bart import BARTRV from pymc_bart.tree import Node, Tree, get_idx_left_child, get_idx_right_child, get_depth +from pymc_bart.split_rules import ContinuousSplitRule class ParticleTree: @@ -141,6 +142,11 @@ def __init__( else: self.alpha_vec = np.ones(self.X.shape[1], dtype=np.int32) + if self.bart.split_rules: + self.split_rules = self.bart.split_rules + else: + self.split_rules = [ContinuousSplitRule] * self.X.shape[1] + init_mean = self.bart.Y.mean() self.num_observations = self.X.shape[0] self.num_variates = self.X.shape[1] @@ -164,6 +170,7 @@ def __init__( idx_data_points=np.arange(self.num_observations, dtype="int32"), num_observations=self.num_observations, shape=self.shape, + split_rules=self.split_rules, ) self.normal = NormalSampler(1, self.shape) @@ -443,13 +450,17 @@ def grow_tree( idx_data_points, available_splitting_values = filter_missing_values( X[idx_data_points, selected_predictor], idx_data_points, missing_data ) - split_value = get_split_value(available_splitting_values) + + split_rule = tree.split_rules[selected_predictor] + + split_value = split_rule.get_split_value(available_splitting_values) if split_value is None: return None - new_idx_data_points = get_new_idx_data_points( - available_splitting_values, split_value, idx_data_points - ) + + to_left = split_rule.divide(available_splitting_values, split_value) + new_idx_data_points = idx_data_points[to_left], idx_data_points[~to_left] + current_node_children = ( get_idx_left_child(index_leaf_node), 
get_idx_right_child(index_leaf_node), @@ -481,12 +492,6 @@ def grow_tree( return current_node_children -@njit -def get_new_idx_data_points(available_splitting_values, split_value, idx_data_points): - split_idx = available_splitting_values <= split_value - return idx_data_points[split_idx], idx_data_points[~split_idx] - - def filter_missing_values(available_splitting_values, idx_data_points, missing_data): if missing_data: mask = ~np.isnan(available_splitting_values) @@ -495,14 +500,6 @@ def filter_missing_values(available_splitting_values, idx_data_points, missing_d return idx_data_points, available_splitting_values -def get_split_value(available_splitting_values): - split_value = None - if available_splitting_values.size > 0: - idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values)) - split_value = available_splitting_values[idx_selected_splitting_values] - return split_value - - def draw_leaf_value( y_mu_pred: npt.NDArray[np.float_], x_mu: npt.NDArray[np.float_], diff --git a/pymc_bart/split_rules.py b/pymc_bart/split_rules.py new file mode 100644 index 0000000..5a0d0cc --- /dev/null +++ b/pymc_bart/split_rules.py @@ -0,0 +1,103 @@ +# Copyright 2022 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from abc import abstractmethod +from numba import njit +import numpy as np + + +class SplitRule: + """ + Abstract template class for a split rule + """ + + @staticmethod + @abstractmethod + def get_split_value(available_splitting_values): + pass + + @staticmethod + @abstractmethod + def divide(available_splitting_values, split_value): + pass + + +class ContinuousSplitRule(SplitRule): + """ + Standard continuous split rule: pick a pivot value and split + depending on whether the variable is less than or equal to the pivot, or greater than it. + """ + + @staticmethod + def get_split_value(available_splitting_values): + split_value = None + if available_splitting_values.size > 1: + idx_selected_splitting_values = int( + np.random.random() * len(available_splitting_values) + ) + split_value = available_splitting_values[idx_selected_splitting_values] + return split_value + + @staticmethod + @njit + def divide(available_splitting_values, split_value): + return available_splitting_values <= split_value + + +class OneHotSplitRule(SplitRule): + """Choose a single categorical value and branch on if the variable is that value or not""" + + @staticmethod + def get_split_value(available_splitting_values): + split_value = None + if available_splitting_values.size > 1 and not np.all( + available_splitting_values == available_splitting_values[0] + ): + idx_selected_splitting_values = int( + np.random.random() * len(available_splitting_values) + ) + split_value = available_splitting_values[idx_selected_splitting_values] + return split_value + + @staticmethod + @njit + def divide(available_splitting_values, split_value): + return available_splitting_values == split_value + + +class SubsetSplitRule(SplitRule): + """ + Choose a random subset of the categorical values and branch on belonging to that set. + This is the approach taken by Sameer K. Deshpande. + flexBART: Flexible Bayesian regression trees with categorical predictors. 
arXiv, + `link <https://arxiv.org/abs/2211.04459>`__ + """ + + @staticmethod + def get_split_value(available_splitting_values): + split_value = None + if available_splitting_values.size > 1 and not np.all( + available_splitting_values == available_splitting_values[0] + ): + unique_values = np.unique(available_splitting_values) + while True: + sample = np.random.randint(0, 2, size=len(unique_values)).astype(bool) + if np.any(sample): + break + split_value = unique_values[sample] + return split_value + + @staticmethod + def divide(available_splitting_values, split_value): + return np.isin(available_splitting_values, split_value) diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 2c16a7b..07d243a 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -13,11 +13,12 @@ # limitations under the License. from functools import lru_cache -from typing import Dict, Generator, List, Optional +from typing import Dict, Generator, List, Optional, Tuple, Union import numpy as np import numpy.typing as npt from pytensor import config +from .split_rules import SplitRule class Node: @@ -39,7 +40,7 @@ def __init__( nvalue: int = 0, idx_data_points: Optional[npt.NDArray[np.int_]] = None, idx_split_variable: int = -1, - linear_params: Optional[List[float]] = None, + linear_params: Optional[List[npt.NDArray[np.float_]]] = None, ) -> None: self.value = value self.nvalue = nvalue @@ -54,7 +55,7 @@ def new_leaf_node( nvalue: int = 0, idx_data_points: Optional[npt.NDArray[np.int_]] = None, idx_split_variable: int = -1, - linear_params: Optional[List[float]] = None, + linear_params: Optional[List[npt.NDArray[np.float_]]] = None, ) -> "Node": return cls( value=value, @@ -114,20 +115,18 @@ class Tree: idx_leaf_nodes : List with the index of the leaf nodes of the tree. 
""" - __slots__ = ( - "tree_structure", - "output", - "idx_leaf_nodes", - ) + __slots__ = ("tree_structure", "output", "idx_leaf_nodes", "split_rules") def __init__( self, tree_structure: Dict[int, Node], output: npt.NDArray[np.float_], + split_rules: List[SplitRule], idx_leaf_nodes: Optional[List[int]] = None, ) -> None: self.tree_structure = tree_structure self.idx_leaf_nodes = idx_leaf_nodes + self.split_rules = split_rules self.output = output @classmethod @@ -137,6 +136,7 @@ def new_tree( idx_data_points: Optional[npt.NDArray[np.int_]], num_observations: int, shape: int, + split_rules: List[SplitRule], ) -> "Tree": return cls( tree_structure={ @@ -148,6 +148,7 @@ def new_tree( }, idx_leaf_nodes=[0], output=np.zeros((num_observations, shape)).astype(config.floatX).squeeze(), + split_rules=split_rules, ) def __getitem__(self, index) -> Node: @@ -168,7 +169,12 @@ def copy(self) -> "Tree": for k, v in self.tree_structure.items() } idx_leaf_nodes = self.idx_leaf_nodes.copy() if self.idx_leaf_nodes is not None else None - return Tree(tree_structure=tree, idx_leaf_nodes=idx_leaf_nodes, output=self.output) + return Tree( + tree_structure=tree, + idx_leaf_nodes=idx_leaf_nodes, + output=self.output, + split_rules=self.split_rules, + ) def get_node(self, index: int) -> Node: return self.tree_structure[index] @@ -202,7 +208,12 @@ def trim(self) -> "Tree": ) for k, v in self.tree_structure.items() } - return Tree(tree_structure=tree, idx_leaf_nodes=None, output=np.array([-1])) + return Tree( + tree_structure=tree, + idx_leaf_nodes=None, + output=np.array([-1]), + split_rules=self.split_rules, + ) def get_split_variables(self) -> Generator[int, None, None]: for node in self.tree_structure.values(): @@ -241,21 +252,22 @@ def predict( """ if excluded is None: excluded = [] - return self._traverse_tree(x=x, excluded=excluded, shape=shape) + + return self._traverse_tree(X=x, excluded=excluded, shape=shape) def _traverse_tree( self, - x: npt.NDArray[np.float_], + X: 
npt.NDArray[np.float_], excluded: Optional[List[int]] = None, - shape: int = 1, + shape: Union[int, Tuple[int, ...]] = 1, ) -> npt.NDArray[np.float_]: """ Traverse the tree starting from the root node given an (un)observed point. Parameters ---------- - x : npt.NDArray[np.float_] - (Un)observed point + X : npt.NDArray[np.float_] + (Un)observed point(s) node_index : int Index of the node to start the traversal from split_variable : int @@ -268,35 +280,47 @@ def _traverse_tree( npt.NDArray[np.float_] Leaf node value or mean of leaf node values """ - stack = [(0, 1.0)] # (node_index, weight) initial state - p_d = np.zeros(shape) + + x_shape = (1,) if len(X.shape) == 1 else X.shape[:-1] + + stack = [(0, np.ones(x_shape))] # (node_index, weight) initial state + p_d = ( + np.zeros(shape + x_shape) if isinstance(shape, tuple) else np.zeros((shape,) + x_shape) + ) while stack: - node_index, weight = stack.pop() + node_index, weights = stack.pop() node = self.get_node(node_index) if node.is_leaf_node(): params = node.linear_params + nd_dims = (...,) + (None,) * len(x_shape) if params is None: - p_d += weight * node.value + p_d += weights * node.value[nd_dims] else: # this produce nonsensical results - p_d += weight * (params[0] + params[1] * x[node.idx_split_variable]) + p_d += weights * ( + params[0][nd_dims] + params[1][nd_dims] * X[..., node.idx_split_variable] + ) # this produce reasonable result # p_d += weight * node.value.mean() else: + left_node_index, right_node_index = get_idx_left_child( + node_index + ), get_idx_right_child(node_index) if excluded is not None and node.idx_split_variable in excluded: - left_node_index, right_node_index = get_idx_left_child( - node_index - ), get_idx_right_child(node_index) prop_nvalue_left = self.get_node(left_node_index).nvalue / node.nvalue - stack.append((left_node_index, weight * prop_nvalue_left)) - stack.append((right_node_index, weight * (1 - prop_nvalue_left))) + stack.append((left_node_index, weights * prop_nvalue_left)) 
+ stack.append((right_node_index, weights * (1 - prop_nvalue_left))) else: - next_node = ( - get_idx_left_child(node_index) - if x[node.idx_split_variable] <= node.value - else get_idx_right_child(node_index) + to_left = ( + self.split_rules[node.idx_split_variable] + .divide(X[..., node.idx_split_variable], node.value) + .astype("float") ) - stack.append((next_node, weight)) + stack.append((left_node_index, weights * to_left)) + stack.append((right_node_index, weights * (1 - to_left))) + + if len(X.shape) == 1: + p_d = p_d[..., 0] return p_d diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index 16fe17a..d6a16eb 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -64,7 +64,7 @@ def _sample_posterior( for ind, p in enumerate(pred): for tree in stacked_trees[idx[ind]]: - p += np.vstack([tree.predict(x=x, excluded=excluded, shape=shape) for x in X]) + p += tree.predict(x=X, excluded=excluded, shape=shape).T pred.reshape((*size_iter, shape, -1)) return pred diff --git a/tests/test_split_rules.py b/tests/test_split_rules.py new file mode 100644 index 0000000..e84810d --- /dev/null +++ b/tests/test_split_rules.py @@ -0,0 +1,41 @@ +import numpy as np + +from pymc_bart.split_rules import ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule +import pytest + + +@pytest.mark.parametrize( + argnames="Rule", + argvalues=[ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule], + ids=["continuous", "one_hot", "subset"], +) +def test_split_rule(Rule): + + # Should return None if only one available value to pick from + assert Rule.get_split_value(np.zeros(1)) is None + + # get_split should return a value divide can use + available_values = np.arange(10).astype(float) + sv = Rule.get_split_value(available_values) + left = Rule.divide(available_values, sv) + + # divide should return a boolean numpy array + # This de facto ensures it is a binary split + assert len(left) == len(available_values) + assert left.dtype == "bool" + + # divide should be deterministic + 
left_repeated = Rule.divide(available_values, sv) + assert (left == left_repeated).all() + + # Most elements should have a chance to go either direction + # NB! This is not 100% necessary, but is a good proxy + probs = np.array( + [ + Rule.divide(available_values, Rule.get_split_value(available_values)) + for _ in range(10000) + ] + ).mean(axis=0) + + assert (probs > 0.01).sum() >= len(available_values) - 1 + assert (probs < 0.99).sum() >= len(available_values) - 1