@@ -189,7 +189,9 @@ def _predict(self) -> npt.NDArray[np.float_]:
189
189
output [leaf_node .idx_data_points ] = leaf_node .value
190
190
return output .T
191
191
192
- def predict (self , x : npt .NDArray [np .float_ ], excluded : Optional [List [int ]] = None ) -> float :
192
+ def predict (
193
+ self , x : npt .NDArray [np .float_ ], excluded : Optional [List [int ]] = None
194
+ ) -> npt .NDArray [np .float_ ]:
193
195
"""
194
196
Predict output of tree for an (un)observed point x.
195
197
@@ -202,7 +204,7 @@ def predict(self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = Non
202
204
203
205
Returns
204
206
-------
205
- float
207
+ npt.NDArray[np.float_]
206
208
Value of the leaf value where the unobserved point lies.
207
209
"""
208
210
if excluded is None :
@@ -211,7 +213,7 @@ def predict(self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = Non
211
213
212
214
def _traverse_tree (
213
215
self , x : npt .NDArray [np .float_ ], node_index : int , excluded : Optional [List [int ]] = None
214
- ) -> float :
216
+ ) -> npt . NDArray [ np . float_ ] :
215
217
"""
216
218
Traverse the tree starting from a particular node given an unobserved point.
217
219
@@ -226,11 +228,12 @@ def _traverse_tree(
226
228
227
229
Returns
228
230
-------
229
- Leaf node value or mean of leaf node values
231
+ npt.NDArray[np.float_]
232
+ Leaf node value or mean of leaf node values
230
233
"""
231
234
current_node : Node = self .get_node (node_index )
232
235
if current_node .is_leaf_node ():
233
- return current_node .value
236
+ return np . array ( current_node .value )
234
237
235
238
if excluded is not None and current_node .idx_split_variable in excluded :
236
239
leaf_values : List [float ] = []
0 commit comments