Skip to content

Commit 2a252a1

Browse files
author
juanitorduz
committed
fix predict output type
1 parent 8d4fc23 commit 2a252a1

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

pymc_bart/tree.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,9 @@ def _predict(self) -> npt.NDArray[np.float_]:
189189
output[leaf_node.idx_data_points] = leaf_node.value
190190
return output.T
191191

192-
def predict(self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = None) -> float:
192+
def predict(
193+
self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = None
194+
) -> npt.NDArray[np.float_]:
193195
"""
194196
Predict output of tree for an (un)observed point x.
195197
@@ -202,7 +204,7 @@ def predict(self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = Non
202204
203205
Returns
204206
-------
205-
float
207+
npt.NDArray[np.float_]
206208
Value of the leaf value where the unobserved point lies.
207209
"""
208210
if excluded is None:
@@ -211,7 +213,7 @@ def predict(self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = Non
211213

212214
def _traverse_tree(
213215
self, x: npt.NDArray[np.float_], node_index: int, excluded: Optional[List[int]] = None
214-
) -> float:
216+
) -> npt.NDArray[np.float_]:
215217
"""
216218
Traverse the tree starting from a particular node given an unobserved point.
217219
@@ -226,11 +228,12 @@ def _traverse_tree(
226228
227229
Returns
228230
-------
229-
Leaf node value or mean of leaf node values
231+
npt.NDArray[np.float_]
232+
Leaf node value or mean of leaf node values
230233
"""
231234
current_node: Node = self.get_node(node_index)
232235
if current_node.is_leaf_node():
233-
return current_node.value
236+
return np.array(current_node.value)
234237

235238
if excluded is not None and current_node.idx_split_variable in excluded:
236239
leaf_values: List[float] = []

0 commit comments

Comments
 (0)