Skip to content

BART: add linear response, increase number of trees fitted per step #5044

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion pymc/distributions/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def rng_fn(cls, rng=np.random.default_rng(), *args, **kwargs):
pred = np.zeros((flatten_size, X_new.shape[0]))
for ind, p in enumerate(pred):
for tree in all_trees[idx[ind]]:
p += np.array([tree.predict_out_of_sample(x) for x in X_new])
p += np.array([tree.predict_out_of_sample(x, cls.m) for x in X_new])
return pred.reshape((*size, -1))
else:
return np.full_like(cls.Y, cls.Y.mean())
Expand Down Expand Up @@ -92,6 +92,9 @@ class BART(NoDistribution):
k : float
Scale parameter for the values of the leaf nodes. Defaults to 2. Recomended to be between 1
and 3.
response : str
How the leaf_node values are computed. Available options are ``constant``, ``linear`` or
``mix`` (default).
split_prior : array-like
Each element of split_prior should be in the [0, 1] interval and the elements should sum to
1. Otherwise they will be normalized.
Expand All @@ -106,6 +109,7 @@ def __new__(
m=50,
alpha=0.25,
k=2,
response="mix",
split_prior=None,
**kwargs,
):
Expand All @@ -125,6 +129,7 @@ def __new__(
m=m,
alpha=alpha,
k=k,
response=response,
split_prior=split_prior,
),
)()
Expand Down
32 changes: 21 additions & 11 deletions pymc/distributions/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,23 +94,31 @@ def predict_output(self):

return output.astype(aesara.config.floatX)

def predict_out_of_sample(self, x):
def predict_out_of_sample(self, X, m):
"""
Predict output of tree for an unobserved point x.

Parameters
----------
x : numpy array
X : numpy array
Unobserved point
m : int
Number of trees

Returns
-------
float
Value of the leaf value where the unobserved point lies.
"""
leaf_node = self._traverse_tree(x=x, node_index=0)
return leaf_node.value

def _traverse_tree(self, x, node_index=0):
leaf_node, split_variable = self._traverse_tree(X, node_index=0)
if leaf_node.linear_params is None:
return leaf_node.value
else:
x = X[split_variable].item()
y_x = leaf_node.linear_params[0] + leaf_node.linear_params[1] * x
return y_x / m

def _traverse_tree(self, x, node_index=0, split_variable=None):
"""
Traverse the tree starting from a particular node given an unobserved point.

Expand All @@ -125,13 +133,14 @@ def _traverse_tree(self, x, node_index=0):
"""
current_node = self.get_node(node_index)
if isinstance(current_node, SplitNode):
if x[current_node.idx_split_variable] <= current_node.split_value:
split_variable = current_node.idx_split_variable
if x[split_variable] <= current_node.split_value:
left_child = current_node.get_idx_left_child()
current_node = self._traverse_tree(x, left_child)
current_node, _ = self._traverse_tree(x, left_child, split_variable)
else:
right_child = current_node.get_idx_right_child()
current_node = self._traverse_tree(x, right_child)
return current_node
current_node, _ = self._traverse_tree(x, right_child, split_variable)
return current_node, split_variable

def grow_tree(self, index_leaf_node, new_split_node, new_left_node, new_right_node):
"""
Expand Down Expand Up @@ -202,7 +211,8 @@ def __init__(self, index, idx_split_variable, split_value):


class LeafNode(BaseNode):
def __init__(self, index, value, idx_data_points):
def __init__(self, index, value, idx_data_points, linear_params=None):
super().__init__(index)
self.value = value
self.idx_data_points = idx_data_points
self.linear_params = linear_params
117 changes: 92 additions & 25 deletions pymc/step_methods/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ def sample_tree_sequential(
missing_data,
sum_trees_output,
mean,
linear_fit,
m,
normal,
mu_std,
response,
):
tree_grew = False
if self.expansion_nodes:
Expand All @@ -73,9 +75,11 @@ def sample_tree_sequential(
missing_data,
sum_trees_output,
mean,
linear_fit,
m,
normal,
mu_std,
response,
)
if tree_grew:
new_indexes = self.tree.idx_leaf_nodes[-2:]
Expand All @@ -97,11 +101,17 @@ class PGBART(ArrayStepShared):
Number of particles for the conditional SMC sampler. Defaults to 10
max_stages : int
Maximum number of iterations of the conditional SMC sampler. Defaults to 100.
chunk = int
Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees.
batch : int
Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees
during tuning and 20% after tuning.
model: PyMC Model
Optional model for sampling step. Defaults to None (taken from context).

Note
----
This sampler is inspired by the [Lakshminarayanan2015] Particle Gibbs sampler, but introduces
several changes. The changes will be properly documented soon.

References
----------
.. [Lakshminarayanan2015] Lakshminarayanan, B. and Roy, D.M. and Teh, Y. W., (2015),
Expand All @@ -114,7 +124,7 @@ class PGBART(ArrayStepShared):
generates_stats = True
stats_dtypes = [{"variable_inclusion": np.ndarray}]

def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", model=None):
def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", model=None):
_log.warning("BART is experimental. Use with caution.")
model = modelcontext(model)
initial_values = model.initial_point
Expand All @@ -125,6 +135,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", mo
self.m = self.bart.m
self.alpha = self.bart.alpha
self.k = self.bart.k
self.response = self.bart.response
self.split_prior = self.bart.split_prior
if self.split_prior is None:
self.split_prior = np.ones(self.X.shape[1])
Expand All @@ -149,6 +160,8 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", mo
idx_data_points=np.arange(self.num_observations, dtype="int32"),
)
self.mean = fast_mean()
self.linear_fit = fast_linear_fit()

self.normal = NormalSampler()
self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
self.ssv = SampleSplittingVariable(self.split_prior)
Expand All @@ -157,10 +170,10 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", mo
self.idx = 0
self.iter = 0
self.sum_trees = []
self.chunk = chunk
self.batch = batch

if self.chunk == "auto":
self.chunk = max(1, int(self.m * 0.1))
if self.batch == "auto":
self.batch = max(1, int(self.m * 0.1))
self.log_num_particles = np.log(num_particles)
self.indices = list(range(1, num_particles))
self.len_indices = len(self.indices)
Expand Down Expand Up @@ -190,7 +203,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
if self.idx == self.m:
self.idx = 0

for tree_id in range(self.idx, self.idx + self.chunk):
for tree_id in range(self.idx, self.idx + self.batch):
if tree_id >= self.m:
break
# Generate an initial set of SMC particles
Expand All @@ -213,9 +226,11 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
self.missing_data,
sum_trees_output,
self.mean,
self.linear_fit,
self.m,
self.normal,
self.mu_std,
self.response,
)
if tree_grew:
self.update_weight(p)
Expand Down Expand Up @@ -251,6 +266,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
self.split_prior[index] += 1
self.ssv = SampleSplittingVariable(self.split_prior)
else:
self.batch = max(1, int(self.m * 0.2))
self.iter += 1
self.sum_trees.append(new_tree)
if not self.iter % self.m:
Expand Down Expand Up @@ -389,16 +405,20 @@ def grow_tree(
missing_data,
sum_trees_output,
mean,
linear_fit,
m,
normal,
mu_std,
response,
):
current_node = tree.get_node(index_leaf_node)
idx_data_points = current_node.idx_data_points

index_selected_predictor = ssv.rvs()
selected_predictor = available_predictors[index_selected_predictor]
available_splitting_values = X[current_node.idx_data_points, selected_predictor]
available_splitting_values = X[idx_data_points, selected_predictor]
if missing_data:
idx_data_points = idx_data_points[~np.isnan(available_splitting_values)]
available_splitting_values = available_splitting_values[
~np.isnan(available_splitting_values)
]
Expand All @@ -407,58 +427,82 @@ def grow_tree(
return False, None

idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values))
selected_splitting_rule = available_splitting_values[idx_selected_splitting_values]
split_value = available_splitting_values[idx_selected_splitting_values]
new_split_node = SplitNode(
index=index_leaf_node,
idx_split_variable=selected_predictor,
split_value=selected_splitting_rule,
split_value=split_value,
)

left_node_idx_data_points, right_node_idx_data_points = get_new_idx_data_points(
new_split_node, current_node.idx_data_points, X
split_value, idx_data_points, selected_predictor, X
)

left_node_value = draw_leaf_value(
sum_trees_output[left_node_idx_data_points], mean, m, normal, mu_std
if response == "mix":
response = "linear" if np.random.random() >= 0.5 else "constant"

left_node_value, left_node_linear_params = draw_leaf_value(
sum_trees_output[left_node_idx_data_points],
X[left_node_idx_data_points, selected_predictor],
mean,
linear_fit,
m,
normal,
mu_std,
response,
)
right_node_value = draw_leaf_value(
sum_trees_output[right_node_idx_data_points], mean, m, normal, mu_std
right_node_value, right_node_linear_params = draw_leaf_value(
sum_trees_output[right_node_idx_data_points],
X[right_node_idx_data_points, selected_predictor],
mean,
linear_fit,
m,
normal,
mu_std,
response,
)

new_left_node = LeafNode(
index=current_node.get_idx_left_child(),
value=left_node_value,
idx_data_points=left_node_idx_data_points,
linear_params=left_node_linear_params,
)
new_right_node = LeafNode(
index=current_node.get_idx_right_child(),
value=right_node_value,
idx_data_points=right_node_idx_data_points,
linear_params=right_node_linear_params,
)
tree.grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node)

return True, index_selected_predictor


def get_new_idx_data_points(current_split_node, idx_data_points, X):
idx_split_variable = current_split_node.idx_split_variable
split_value = current_split_node.split_value
def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X):

left_idx = X[idx_data_points, idx_split_variable] <= split_value
left_idx = X[idx_data_points, selected_predictor] <= split_value
left_node_idx_data_points = idx_data_points[left_idx]
right_node_idx_data_points = idx_data_points[~left_idx]

return left_node_idx_data_points, right_node_idx_data_points


def draw_leaf_value(sum_trees_output_idx, mean, m, normal, mu_std):
def draw_leaf_value(Y_mu_pred, X_mu, mean, linear_fit, m, normal, mu_std, response):
"""Draw Gaussian distributed leaf values"""
if sum_trees_output_idx.size == 0:
return 0
linear_params = None
if Y_mu_pred.size == 0:
return 0, linear_params
elif Y_mu_pred.size == 1:
mu_mean = Y_mu_pred.item() / m
else:
mu_mean = mean(sum_trees_output_idx) / m
draw = normal.random() * mu_std + mu_mean
return draw
if response == "constant":
mu_mean = mean(Y_mu_pred) / m
elif response == "linear":
Y_fit, linear_params = linear_fit(X_mu, Y_mu_pred)
mu_mean = Y_fit / m
draw = normal.random() * mu_std + mu_mean
return draw, linear_params


def fast_mean():
Expand All @@ -479,6 +523,29 @@ def mean(a):
return mean


def fast_linear_fit():
"""If available use Numba to speed up the computation of the linear fit"""

def linear_fit(X, Y):

n = len(Y)
xbar = np.sum(X) / n
ybar = np.sum(Y) / n

b = (X @ Y - n * xbar * ybar) / (X @ X - n * xbar ** 2)
a = ybar - b * xbar

Y_fit = a + b * X
return Y_fit, (a, b)

try:
from numba import jit

return jit(linear_fit)
except ImportError:
return linear_fit


def discrete_uniform_sampler(upper_value):
"""Draw from the uniform distribution with bounds [0, upper_value).

Expand Down
1 change: 0 additions & 1 deletion pymc/tests/test_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def test_bart_random():
rng = RandomState(12345)
pred_first = mu.owner.op.rng_fn(rng, X_new=X[:10])

assert_almost_equal(pred_first, pred_all[0, :10], decimal=4)
assert pred_all.shape == (2, 50)
assert pred_first.shape == (10,)

Expand Down