Skip to content

Commit d228be3

Browse files
committed
Fix mypy annotations
1 parent 78439e6 commit d228be3

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

pymc_bart/bart.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,6 @@ def __new__(
125125
if split_prior is None:
126126
split_prior = []
127127

128-
if split_rules is None:
129-
split_rules = []
130-
131128
bart_op = type(
132129
f"BART_{name}",
133130
(BARTRV,),

pymc_bart/tree.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
# limitations under the License.
1414

1515
from functools import lru_cache
16-
from typing import Dict, Generator, List, Optional
16+
from typing import Dict, Generator, List, Optional, Tuple, Union
1717

1818
import numpy as np
1919
import numpy.typing as npt
2020
from pytensor import config
21+
from .split_rules import SplitRule
2122

2223

2324
class Node:
@@ -39,7 +40,7 @@ def __init__(
3940
nvalue: int = 0,
4041
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
4142
idx_split_variable: int = -1,
42-
linear_params: Optional[List[float]] = None,
43+
linear_params: Optional[List[npt.NDArray[np.float_]]] = None,
4344
) -> None:
4445
self.value = value
4546
self.nvalue = nvalue
@@ -54,7 +55,7 @@ def new_leaf_node(
5455
nvalue: int = 0,
5556
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
5657
idx_split_variable: int = -1,
57-
linear_params: Optional[List[float]] = None,
58+
linear_params: Optional[List[npt.NDArray[np.float_]]] = None,
5859
) -> "Node":
5960
return cls(
6061
value=value,
@@ -120,7 +121,7 @@ def __init__(
120121
self,
121122
tree_structure: Dict[int, Node],
122123
output: npt.NDArray[np.float_],
123-
split_rules: List[object],
124+
split_rules: List[SplitRule],
124125
idx_leaf_nodes: Optional[List[int]] = None,
125126
) -> None:
126127
self.tree_structure = tree_structure
@@ -135,7 +136,7 @@ def new_tree(
135136
idx_data_points: Optional[npt.NDArray[np.int_]],
136137
num_observations: int,
137138
shape: int,
138-
split_rules: List[object],
139+
split_rules: List[SplitRule],
139140
) -> "Tree":
140141
return cls(
141142
tree_structure={
@@ -258,7 +259,7 @@ def _traverse_tree(
258259
self,
259260
X: npt.NDArray[np.float_],
260261
excluded: Optional[List[int]] = None,
261-
shape: int = 1,
262+
shape: Union[int,Tuple[int,...]] = 1,
262263
) -> npt.NDArray[np.float_]:
263264
"""
264265
Traverse the tree starting from the root node given an (un)observed point.

0 commit comments

Comments
 (0)