@@ -110,7 +110,7 @@ def __init__(
110
110
self ,
111
111
tree_structure : Dict [int , Node ],
112
112
idx_leaf_nodes : Optional [npt .NDArray [np .int_ ]],
113
- output : Optional [ npt .NDArray [np .float_ ] ],
113
+ output : npt .NDArray [np .float_ ],
114
114
) -> None :
115
115
self .tree_structure = tree_structure
116
116
self .idx_leaf_nodes = idx_leaf_nodes
@@ -173,24 +173,21 @@ def trim(self) -> "Tree":
173
173
k : Node (v .index , v .value , None , v .idx_split_variable )
174
174
for k , v in self .tree_structure .items ()
175
175
}
176
- return Tree (tree , None , None )
176
+ return Tree (tree_structure = tree , idx_leaf_nodes = None , output = np . array ([ - 1 ]) )
177
177
178
178
def get_split_variables (self ) -> Generator [int , None , None ]:
179
179
for node in self .tree_structure .values ():
180
180
if node .is_split_node ():
181
181
yield node .idx_split_variable
182
182
183
- def _predict (self ) -> Optional [ npt .NDArray [np .float_ ] ]:
183
+ def _predict (self ) -> npt .NDArray [np .float_ ]:
184
184
output = self .output
185
185
186
- if output is None :
187
- return None
188
-
189
186
if self .idx_leaf_nodes is not None :
190
187
for node_index in self .idx_leaf_nodes :
191
188
leaf_node = self .get_node (node_index )
192
189
output [leaf_node .idx_data_points ] = leaf_node .value
193
- return output .T if output is not None else None
190
+ return output .T
194
191
195
192
def predict (self , x : npt .NDArray [np .float_ ], excluded : Optional [List [int ]] = None ) -> float :
196
193
"""
@@ -244,7 +241,7 @@ def _traverse_tree(
244
241
next_node = current_node .get_idx_left_child ()
245
242
else :
246
243
next_node = current_node .get_idx_right_child ()
247
- return self ._traverse_tree (x , next_node , excluded )
244
+ return self ._traverse_tree (x = x , node_index = next_node , excluded = excluded )
248
245
249
246
def _traverse_leaf_values (self , leaf_values : List [float ], node_index : int ) -> None :
250
247
"""
0 commit comments