Skip to content

Commit 8d4fc23

Browse files
author
juanitorduz
committed
remove outout None type
1 parent 071c8cc commit 8d4fc23

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

pymc_bart/tree.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def __init__(
110110
self,
111111
tree_structure: Dict[int, Node],
112112
idx_leaf_nodes: Optional[npt.NDArray[np.int_]],
113-
output: Optional[npt.NDArray[np.float_]],
113+
output: npt.NDArray[np.float_],
114114
) -> None:
115115
self.tree_structure = tree_structure
116116
self.idx_leaf_nodes = idx_leaf_nodes
@@ -173,24 +173,21 @@ def trim(self) -> "Tree":
173173
k: Node(v.index, v.value, None, v.idx_split_variable)
174174
for k, v in self.tree_structure.items()
175175
}
176-
return Tree(tree, None, None)
176+
return Tree(tree_structure=tree, idx_leaf_nodes=None, output=np.array([-1]))
177177

178178
def get_split_variables(self) -> Generator[int, None, None]:
179179
for node in self.tree_structure.values():
180180
if node.is_split_node():
181181
yield node.idx_split_variable
182182

183-
def _predict(self) -> Optional[npt.NDArray[np.float_]]:
183+
def _predict(self) -> npt.NDArray[np.float_]:
184184
output = self.output
185185

186-
if output is None:
187-
return None
188-
189186
if self.idx_leaf_nodes is not None:
190187
for node_index in self.idx_leaf_nodes:
191188
leaf_node = self.get_node(node_index)
192189
output[leaf_node.idx_data_points] = leaf_node.value
193-
return output.T if output is not None else None
190+
return output.T
194191

195192
def predict(self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = None) -> float:
196193
"""
@@ -244,7 +241,7 @@ def _traverse_tree(
244241
next_node = current_node.get_idx_left_child()
245242
else:
246243
next_node = current_node.get_idx_right_child()
247-
return self._traverse_tree(x, next_node, excluded)
244+
return self._traverse_tree(x=x, node_index=next_node, excluded=excluded)
248245

249246
def _traverse_leaf_values(self, leaf_values: List[float], node_index: int) -> None:
250247
"""

0 commit comments

Comments
 (0)