Skip to content

Commit b4043ba

Browse files
howsiyuSi Yu How
and
Si Yu How
authored
We don't need to store index at each node. (#80)
* We don't need to store index at each node. * Fast integer log2 --------- Co-authored-by: Si Yu How <siyuhow@siyu.colet10>
1 parent be9dbf0 commit b4043ba

File tree

3 files changed

+37
-49
lines changed

3 files changed

+37
-49
lines changed

pymc_bart/pgbart.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from pytensor.tensor.var import Variable
2626

2727
from pymc_bart.bart import BARTRV
28-
from pymc_bart.tree import Node, Tree, get_depth
28+
from pymc_bart.tree import Node, Tree, get_idx_left_child, get_idx_right_child, get_depth
2929

3030

3131
class ParticleTree:
@@ -411,11 +411,10 @@ def grow_tree(
411411
available_splitting_values, split_value, idx_data_points
412412
)
413413
current_node_children = (
414-
current_node.get_idx_left_child(),
415-
current_node.get_idx_right_child(),
414+
get_idx_left_child(index_leaf_node),
415+
get_idx_right_child(index_leaf_node),
416416
)
417417

418-
new_nodes = np.empty(2, dtype=object)
419418
for idx in range(2):
420419
idx_data_point = new_idx_data_points[idx]
421420
node_value = draw_leaf_value(
@@ -426,17 +425,13 @@ def grow_tree(
426425
)
427426

428427
new_node = Node.new_leaf_node(
429-
index=current_node_children[idx],
430428
value=node_value,
431429
idx_data_points=idx_data_point,
432430
)
433-
new_nodes[idx] = new_node
431+
tree.set_node(current_node_children[idx], new_node)
434432

435433
tree.grow_leaf_node(current_node, selected_predictor, split_value, index_leaf_node)
436-
tree.set_node(new_nodes[0].index, new_nodes[0])
437-
tree.set_node(new_nodes[1].index, new_nodes[1])
438-
439-
return [new_nodes[0].index, new_nodes[1].index]
434+
return current_node_children
440435

441436

442437
@njit

pymc_bart/tree.py

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import math
1615
from functools import lru_cache
1716
from typing import Dict, Generator, List, Optional
1817

@@ -26,41 +25,30 @@ class Node:
2625
2726
Attributes
2827
----------
29-
index : int
3028
value : float
3129
idx_data_points : Optional[npt.NDArray[np.int_]]
3230
idx_split_variable : Optional[npt.NDArray[np.int_]]
3331
"""
3432

35-
__slots__ = "index", "value", "idx_split_variable", "idx_data_points"
33+
__slots__ = "value", "idx_split_variable", "idx_data_points"
3634

3735
def __init__(
3836
self,
39-
index: int,
4037
value: float = -1.0,
4138
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
4239
idx_split_variable: int = -1,
4340
) -> None:
44-
self.index = index
4541
self.value = value
4642
self.idx_data_points = idx_data_points
4743
self.idx_split_variable = idx_split_variable
4844

4945
@classmethod
50-
def new_leaf_node(
51-
cls, index: int, value: float, idx_data_points: Optional[npt.NDArray[np.int_]]
52-
) -> "Node":
53-
return cls(index, value=value, idx_data_points=idx_data_points)
46+
def new_leaf_node(cls, value: float, idx_data_points: Optional[npt.NDArray[np.int_]]) -> "Node":
47+
return cls(value=value, idx_data_points=idx_data_points)
5448

5549
@classmethod
56-
def new_split_node(cls, index: int, split_value: float, idx_split_variable: int) -> "Node":
57-
return cls(index=index, value=split_value, idx_split_variable=idx_split_variable)
58-
59-
def get_idx_left_child(self) -> int:
60-
return self.index * 2 + 1
61-
62-
def get_idx_right_child(self) -> int:
63-
return self.index * 2 + 2
50+
def new_split_node(cls, split_value: float, idx_split_variable: int) -> "Node":
51+
return cls(value=split_value, idx_split_variable=idx_split_variable)
6452

6553
def is_split_node(self) -> bool:
6654
return self.idx_split_variable >= 0
@@ -69,9 +57,17 @@ def is_leaf_node(self) -> bool:
6957
return not self.is_split_node()
7058

7159

60+
def get_idx_left_child(index) -> int:
61+
return index * 2 + 1
62+
63+
64+
def get_idx_right_child(index) -> int:
65+
return index * 2 + 2
66+
67+
7268
@lru_cache
7369
def get_depth(index: int) -> int:
74-
return math.floor(math.log2(index + 1))
70+
return (index + 1).bit_length() - 1
7571

7672

7773
class Tree:
@@ -126,9 +122,7 @@ def new_tree(
126122
) -> "Tree":
127123
return cls(
128124
tree_structure={
129-
0: Node.new_leaf_node(
130-
index=0, value=leaf_node_value, idx_data_points=idx_data_points
131-
)
125+
0: Node.new_leaf_node(value=leaf_node_value, idx_data_points=idx_data_points)
132126
},
133127
idx_leaf_nodes=[0],
134128
output=np.zeros((num_observations, shape)).astype(config.floatX).squeeze(),
@@ -142,7 +136,7 @@ def __setitem__(self, index, node) -> None:
142136

143137
def copy(self) -> "Tree":
144138
tree: Dict[int, Node] = {
145-
k: Node(v.index, v.value, v.idx_data_points, v.idx_split_variable)
139+
k: Node(v.value, v.idx_data_points, v.idx_split_variable)
146140
for k, v in self.tree_structure.items()
147141
}
148142
idx_leaf_nodes = self.idx_leaf_nodes.copy() if self.idx_leaf_nodes is not None else None
@@ -167,8 +161,7 @@ def grow_leaf_node(
167161

168162
def trim(self) -> "Tree":
169163
tree: Dict[int, Node] = {
170-
k: Node(v.index, v.value, None, v.idx_split_variable)
171-
for k, v in self.tree_structure.items()
164+
k: Node(v.value, None, v.idx_split_variable) for k, v in self.tree_structure.items()
172165
}
173166
return Tree(tree_structure=tree, idx_leaf_nodes=None, output=np.array([-1]))
174167

@@ -241,9 +234,9 @@ def _traverse_tree(
241234
return np.mean(leaf_values, axis=0)
242235

243236
if x[current_node.idx_split_variable] <= current_node.value:
244-
next_node = current_node.get_idx_left_child()
237+
next_node = get_idx_left_child(node_index)
245238
else:
246-
next_node = current_node.get_idx_right_child()
239+
next_node = get_idx_right_child(node_index)
247240
return self._traverse_tree(x=x, node_index=next_node, excluded=excluded)
248241

249242
def _traverse_leaf_values(self, leaf_values: List[float], node_index: int) -> None:
@@ -259,5 +252,5 @@ def _traverse_leaf_values(self, leaf_values: List[float], node_index: int) -> No
259252
if node.is_leaf_node():
260253
leaf_values.append(node.value)
261254
else:
262-
self._traverse_leaf_values(leaf_values, node.get_idx_left_child())
263-
self._traverse_leaf_values(leaf_values, node.get_idx_right_child())
255+
self._traverse_leaf_values(leaf_values, get_idx_left_child(node_index))
256+
self._traverse_leaf_values(leaf_values, get_idx_right_child(node_index))

tests/test_tree.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,29 @@
11
import numpy as np
22

3-
from pymc_bart.tree import Node, get_depth
3+
from pymc_bart.tree import Node, get_idx_left_child, get_idx_right_child, get_depth
44

55

66
def test_split_node():
7-
split_node = Node.new_split_node(index=5, idx_split_variable=2, split_value=3.0)
8-
assert split_node.index == 5
9-
assert get_depth(split_node.index) == 2
7+
index = 5
8+
split_node = Node.new_split_node(idx_split_variable=2, split_value=3.0)
9+
assert get_depth(index) == 2
1010
assert split_node.value == 3.0
1111
assert split_node.idx_split_variable == 2
1212
assert split_node.idx_data_points is None
13-
assert split_node.get_idx_left_child() == 11
14-
assert split_node.get_idx_right_child() == 12
13+
assert get_idx_left_child(index) == 11
14+
assert get_idx_right_child(index) == 12
1515
assert split_node.is_split_node() is True
1616
assert split_node.is_leaf_node() is False
1717

1818

1919
def test_leaf_node():
20-
leaf_node = Node.new_leaf_node(index=5, value=3.14, idx_data_points=[1, 2, 3])
21-
assert leaf_node.index == 5
22-
assert get_depth(leaf_node.index) == 2
20+
index = 5
21+
leaf_node = Node.new_leaf_node(value=3.14, idx_data_points=[1, 2, 3])
22+
assert get_depth(index) == 2
2323
assert leaf_node.value == 3.14
2424
assert leaf_node.idx_split_variable == -1
2525
assert np.array_equal(leaf_node.idx_data_points, [1, 2, 3])
26-
assert leaf_node.get_idx_left_child() == 11
27-
assert leaf_node.get_idx_right_child() == 12
26+
assert get_idx_left_child(index) == 11
27+
assert get_idx_right_child(index) == 12
2828
assert leaf_node.is_split_node() is False
2929
assert leaf_node.is_leaf_node() is True

0 commit comments

Comments
 (0)