Skip to content

Linear nodes attempt (v2) #79

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
May 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
fe38d53
linear nodes init
Apr 3, 2023
ce3ed63
fix lint
Apr 13, 2023
911ba73
Update pymc_bart/pgbart.py
juanitorduz May 11, 2023
613d0ec
address comments
juanitorduz May 11, 2023
e9c67c9
fix arg error
juanitorduz May 11, 2023
e5c8bd3
add bracket
juanitorduz May 11, 2023
64dcd07
clean func
juanitorduz May 12, 2023
201b85e
fix rm init param linear_params
juanitorduz May 12, 2023
7ff4a21
clean up
juanitorduz May 12, 2023
6c327df
minor improvements
juanitorduz May 12, 2023
83abae6
fix mean denominator
juanitorduz May 12, 2023
796c9a6
fix shape
juanitorduz May 12, 2023
f4e7074
add tests
juanitorduz May 12, 2023
5b7f67c
fix working implementation
juanitorduz May 13, 2023
de714f9
try fix shape vect
juanitorduz May 13, 2023
7ec226e
undo last commit
juanitorduz May 13, 2023
89fd59e
Update pymc_bart/pgbart.py
juanitorduz May 16, 2023
996596f
address comments
juanitorduz May 16, 2023
0c0209d
change default response
juanitorduz May 16, 2023
f4fcce1
better unpacking
juanitorduz May 16, 2023
db767ce
some type hints
juanitorduz May 16, 2023
6d46784
fix trim method
juanitorduz May 18, 2023
c2de9fb
handle zero variance fast linear fit
juanitorduz May 23, 2023
8c6963d
simplify condition
juanitorduz May 23, 2023
f2c546a
linear response in more than 1d
aloctavodia May 25, 2023
779af0d
make node value type an array so that is compatible with vectorization
juanitorduz May 29, 2023
701d8e1
rm .item() from fast linear model output
juanitorduz May 29, 2023
c242628
improve missing values filters
juanitorduz May 29, 2023
c08f86b
improve missing values filters
juanitorduz May 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,9 @@ good-names=i,
rv,
new_X,
new_y,
a,
b,
n,


# Include a hint for the correct naming format with invalid-name
Expand Down
7 changes: 6 additions & 1 deletion pymc_bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def rng_fn(cls, rng=None, X=None, Y=None, m=None, alpha=None, split_prior=None,
else:
return np.full(cls.Y.shape[0], cls.Y.mean())
else:
return _sample_posterior(cls.all_trees, cls.X, rng=rng).squeeze().T
return _sample_posterior(cls.all_trees, cls.X, cls.m, rng=rng).squeeze().T


bart = BARTRV()
Expand All @@ -72,6 +72,9 @@ class BART(Distribution):
The response vector.
m : int
Number of trees
response : str
How the leaf_node values are computed. Available options are ``constant``, ``linear`` or
``mix``. Defaults to ``constant``.
alpha : float
Control the prior probability over the depth of the trees. Even when it can takes values in
the interval (0, 1), it is recommended to be in the interval (0, 0.5].
Expand All @@ -88,6 +91,7 @@ def __new__(
Y: TensorLike,
m: int = 50,
alpha: float = 0.25,
response: str = "constant",
split_prior: Optional[List[float]] = None,
**kwargs,
):
Expand All @@ -110,6 +114,7 @@ def __new__(
X=X,
Y=Y,
m=m,
response=response,
alpha=alpha,
split_prior=split_prior,
),
Expand Down
92 changes: 72 additions & 20 deletions pymc_bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def sample_tree(
missing_data,
sum_trees,
m,
response,
normal,
shape,
) -> bool:
Expand All @@ -73,6 +74,7 @@ def sample_tree(
missing_data,
sum_trees,
m,
response,
normal,
self.kfactor,
shape,
Expand Down Expand Up @@ -131,6 +133,7 @@ def __init__(

self.missing_data = np.any(np.isnan(self.X))
self.m = self.bart.m
self.response = self.bart.response
shape = initial_values[value_bart.name].shape
self.shape = 1 if len(shape) == 1 else shape[0]

Expand Down Expand Up @@ -160,6 +163,7 @@ def __init__(
num_observations=self.num_observations,
shape=self.shape,
)

self.normal = NormalSampler(mu_std, self.shape)
self.uniform = UniformSampler(0, 1)
self.uniform_kf = UniformSampler(0.33, 0.75, self.shape)
Expand Down Expand Up @@ -209,6 +213,7 @@ def astep(self, _):
self.missing_data,
self.sum_trees,
self.m,
self.response,
self.normal,
self.shape,
):
Expand Down Expand Up @@ -393,6 +398,7 @@ def grow_tree(
missing_data,
sum_trees,
m,
response,
normal,
kfactor,
shape,
Expand All @@ -402,8 +408,10 @@ def grow_tree(

index_selected_predictor = ssv.rvs()
selected_predictor = available_predictors[index_selected_predictor]
available_splitting_values = X[idx_data_points, selected_predictor]
split_value = get_split_value(available_splitting_values, idx_data_points, missing_data)
idx_data_points, available_splitting_values = filter_missing_values(
X[idx_data_points, selected_predictor], idx_data_points, missing_data
)
split_value = get_split_value(available_splitting_values)

if split_value is None:
return None
Expand All @@ -415,18 +423,24 @@ def grow_tree(
get_idx_right_child(index_leaf_node),
)

if response == "mix":
response = "linear" if np.random.random() >= 0.5 else "constant"

for idx in range(2):
idx_data_point = new_idx_data_points[idx]
node_value = draw_leaf_value(
sum_trees[:, idx_data_point],
m,
normal.rvs() * kfactor,
shape,
node_value, linear_params = draw_leaf_value(
y_mu_pred=sum_trees[:, idx_data_point],
x_mu=X[idx_data_point, selected_predictor],
m=m,
norm=normal.rvs() * kfactor,
shape=shape,
response=response,
)

new_node = Node.new_leaf_node(
value=node_value,
idx_data_points=idx_data_point,
linear_params=linear_params,
)
tree.set_node(current_node_children[idx], new_node)

Expand All @@ -440,39 +454,52 @@ def get_new_idx_data_points(available_splitting_values, split_value, idx_data_po
return idx_data_points[split_idx], idx_data_points[~split_idx]


def get_split_value(available_splitting_values, idx_data_points, missing_data):
def filter_missing_values(available_splitting_values, idx_data_points, missing_data):
if missing_data:
idx_data_points = idx_data_points[~np.isnan(available_splitting_values)]
available_splitting_values = available_splitting_values[
~np.isnan(available_splitting_values)
]
mask = ~np.isnan(available_splitting_values)
idx_data_points = idx_data_points[mask]
available_splitting_values = available_splitting_values[mask]
return idx_data_points, available_splitting_values


def get_split_value(available_splitting_values):
    """Pick one splitting value uniformly at random.

    Returns ``None`` when there are no candidate values to split on
    (e.g. every value was filtered out as missing).
    """
    if available_splitting_values.size == 0:
        return None
    chosen_idx = discrete_uniform_sampler(len(available_splitting_values))
    return available_splitting_values[chosen_idx]


@njit
def draw_leaf_value(y_mu_pred, m, norm, shape):
def draw_leaf_value(
y_mu_pred: npt.NDArray[np.float_],
x_mu: npt.NDArray[np.float_],
m: int,
norm: npt.NDArray[np.float_],
shape: int,
response: str,
) -> Tuple[npt.NDArray[np.float_], Optional[npt.NDArray[np.float_]]]:
"""Draw Gaussian distributed leaf values."""
linear_params = None
mu_mean = np.empty(shape)
if y_mu_pred.size == 0:
return np.zeros(shape)
return np.zeros(shape), linear_params

if y_mu_pred.size == 1:
mu_mean = np.full(shape, y_mu_pred.item() / m)
else:
mu_mean = fast_mean(y_mu_pred) / m
if response == "constant":
mu_mean = fast_mean(y_mu_pred) / m
if response == "linear":
y_fit, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred)
mu_mean = y_fit / m

return norm + mu_mean
draw = norm + mu_mean
return draw, linear_params


@njit
def fast_mean(ari):
def fast_mean(ari: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_]]:
"""Use Numba to speed up the computation of the mean."""

if ari.ndim == 1:
count = ari.shape[0]
suma = 0
Expand All @@ -488,6 +515,31 @@ def fast_mean(ari):
return res / count


@njit
def fast_linear_fit(
    x: npt.NDArray[np.float_], y: npt.NDArray[np.float_]
) -> Tuple[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]]:
    """Least-squares fit of each row of ``y`` against the 1-d predictor ``x``.

    Numba-compiled closed-form simple linear regression. Returns the fitted
    values (transposed so observations run along axis 0) together with the
    intercept/slope pair ``[a, b]``, one entry per row of ``y``.
    """
    num_points = len(x)

    x_mean = np.sum(x) / num_points
    y_mean = np.sum(y, axis=1) / num_points

    x_centered = x - x_mean
    y_centered = y - np.expand_dims(y_mean, axis=1)

    x_var = np.dot(x_centered, x_centered.T)

    # A zero-variance predictor has no defined slope; fall back to a flat fit.
    if x_var == 0:
        b = np.zeros(y.shape[0])
    else:
        b = np.dot(x_centered, y_centered.T) / x_var

    a = y_mean - b * x_mean

    y_fit = np.expand_dims(a, axis=1) + np.expand_dims(b, axis=1) * x
    return y_fit.T, [a, b]


def discrete_uniform_sampler(upper_value):
"""Draw from the uniform distribution with bounds [0, upper_value).

Expand Down
Loading