12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- import math
16
15
from functools import lru_cache
17
16
from typing import Dict , Generator , List , Optional
18
17
@@ -26,41 +25,30 @@ class Node:
26
25
27
26
Attributes
28
27
----------
29
- index : int
30
28
value : float
31
29
idx_data_points : Optional[npt.NDArray[np.int_]]
32
30
idx_split_variable : int
33
31
"""
34
32
35
- __slots__ = "index" , " value" , "idx_split_variable" , "idx_data_points"
33
+ __slots__ = "value" , "idx_split_variable" , "idx_data_points"
36
34
37
35
def __init__(
    self,
    value: float = -1.0,
    idx_data_points: Optional[npt.NDArray[np.int_]] = None,
    idx_split_variable: int = -1,
) -> None:
    """Store the payload of a single tree node.

    The node no longer records its own position; its index is the key
    under which the owning tree stores it.

    Parameters
    ----------
    value : float
        Value held by the node (defaults to ``-1.0``).
    idx_data_points : Optional[npt.NDArray[np.int_]]
        Indices of the data points associated with the node, or ``None``.
    idx_split_variable : int
        Index of the splitting variable. The default ``-1`` is negative,
        which ``is_split_node`` reports as "not a split node".
    """
    self.value = value
    self.idx_data_points = idx_data_points
    self.idx_split_variable = idx_split_variable
48
44
49
45
@classmethod
def new_leaf_node(cls, value: float, idx_data_points: Optional[npt.NDArray[np.int_]]) -> "Node":
    """Build a leaf node.

    Only ``value`` and ``idx_data_points`` are forwarded to the constructor,
    so ``idx_split_variable`` keeps its default of ``-1``.

    Parameters
    ----------
    value : float
        Value stored in the leaf.
    idx_data_points : Optional[npt.NDArray[np.int_]]
        Indices of the data points associated with the leaf, or ``None``.

    Returns
    -------
    Node
        The newly created leaf node.
    """
    return cls(value=value, idx_data_points=idx_data_points)
54
48
55
49
@classmethod
def new_split_node(cls, split_value: float, idx_split_variable: int) -> "Node":
    """Build a split (internal) node.

    ``idx_data_points`` is not passed, so it keeps the constructor default
    of ``None``.

    Parameters
    ----------
    split_value : float
        Threshold stored as the node's ``value``.
    idx_split_variable : int
        Index of the variable the node splits on.

    Returns
    -------
    Node
        The newly created split node.
    """
    return cls(value=split_value, idx_split_variable=idx_split_variable)
64
52
65
53
def is_split_node(self) -> bool:
    """Return ``True`` when the node has a non-negative ``idx_split_variable``."""
    return self.idx_split_variable >= 0
@@ -69,9 +57,17 @@ def is_leaf_node(self) -> bool:
69
57
return not self .is_split_node ()
70
58
71
59
60
def get_idx_left_child(index: int) -> int:
    """Return the index of the left child of the node stored at ``index``.

    Parameters
    ----------
    index : int
        Index of the parent node.

    Returns
    -------
    int
        ``index * 2 + 1``.
    """
    return index * 2 + 1
62
+
63
+
64
def get_idx_right_child(index: int) -> int:
    """Return the index of the right child of the node stored at ``index``.

    Parameters
    ----------
    index : int
        Index of the parent node.

    Returns
    -------
    int
        ``index * 2 + 2``.
    """
    return index * 2 + 2
66
+
67
+
72
68
@lru_cache
def get_depth(index: int) -> int:
    """Return the depth of the node stored at ``index``.

    With children at ``2 * i + 1`` and ``2 * i + 2``, index 0 is at depth 0,
    indices 1-2 at depth 1, indices 3-6 at depth 2, and so on. The result
    equals ``floor(log2(index + 1))`` but is computed with integer
    arithmetic only, so it does not depend on floating-point rounding.
    Results are memoized through ``lru_cache``.

    Parameters
    ----------
    index : int
        Node index; assumed to be non-negative.

    Returns
    -------
    int
        Depth of the node.
    """
    return (index + 1).bit_length() - 1
75
71
76
72
77
73
class Tree :
@@ -126,9 +122,7 @@ def new_tree(
126
122
) -> "Tree" :
127
123
return cls (
128
124
tree_structure = {
129
- 0 : Node .new_leaf_node (
130
- index = 0 , value = leaf_node_value , idx_data_points = idx_data_points
131
- )
125
+ 0 : Node .new_leaf_node (value = leaf_node_value , idx_data_points = idx_data_points )
132
126
},
133
127
idx_leaf_nodes = [0 ],
134
128
output = np .zeros ((num_observations , shape )).astype (config .floatX ).squeeze (),
@@ -142,7 +136,7 @@ def __setitem__(self, index, node) -> None:
142
136
143
137
def copy (self ) -> "Tree" :
144
138
tree : Dict [int , Node ] = {
145
- k : Node (v .index , v . value , v .idx_data_points , v .idx_split_variable )
139
+ k : Node (v .value , v .idx_data_points , v .idx_split_variable )
146
140
for k , v in self .tree_structure .items ()
147
141
}
148
142
idx_leaf_nodes = self .idx_leaf_nodes .copy () if self .idx_leaf_nodes is not None else None
@@ -167,8 +161,7 @@ def grow_leaf_node(
167
161
168
162
def trim(self) -> "Tree":
    """Return a reduced copy of the tree without per-node data-point indices.

    Every node is rebuilt with the same ``value`` and ``idx_split_variable``
    but with ``idx_data_points`` set to ``None``. The returned tree has
    ``idx_leaf_nodes=None`` and a placeholder ``output`` of
    ``np.array([-1])``.

    Returns
    -------
    Tree
        The trimmed tree; the original tree is left untouched.
    """
    tree: Dict[int, Node] = {
        k: Node(v.value, None, v.idx_split_variable) for k, v in self.tree_structure.items()
    }
    return Tree(tree_structure=tree, idx_leaf_nodes=None, output=np.array([-1]))
174
167
@@ -241,9 +234,9 @@ def _traverse_tree(
241
234
return np .mean (leaf_values , axis = 0 )
242
235
243
236
if x [current_node .idx_split_variable ] <= current_node .value :
244
- next_node = current_node . get_idx_left_child ()
237
+ next_node = get_idx_left_child (node_index )
245
238
else :
246
- next_node = current_node . get_idx_right_child ()
239
+ next_node = get_idx_right_child (node_index )
247
240
return self ._traverse_tree (x = x , node_index = next_node , excluded = excluded )
248
241
249
242
def _traverse_leaf_values (self , leaf_values : List [float ], node_index : int ) -> None :
@@ -259,5 +252,5 @@ def _traverse_leaf_values(self, leaf_values: List[float], node_index: int) -> No
259
252
if node .is_leaf_node ():
260
253
leaf_values .append (node .value )
261
254
else :
262
- self ._traverse_leaf_values (leaf_values , node . get_idx_left_child ())
263
- self ._traverse_leaf_values (leaf_values , node . get_idx_right_child ())
255
+ self ._traverse_leaf_values (leaf_values , get_idx_left_child (node_index ))
256
+ self ._traverse_leaf_values (leaf_values , get_idx_right_child (node_index ))
0 commit comments