13
13
# limitations under the License.
14
14
15
15
from functools import lru_cache
16
- from typing import Dict , Generator , List , Optional
16
+ from typing import Dict , Generator , List , Optional , Tuple , Union
17
17
18
18
import numpy as np
19
19
import numpy .typing as npt
20
20
from pytensor import config
21
+ from .split_rules import SplitRule
21
22
22
23
23
24
class Node :
@@ -39,7 +40,7 @@ def __init__(
39
40
nvalue : int = 0 ,
40
41
idx_data_points : Optional [npt .NDArray [np .int_ ]] = None ,
41
42
idx_split_variable : int = - 1 ,
42
- linear_params : Optional [List [float ]] = None ,
43
+ linear_params : Optional [List [npt . NDArray [ np . float_ ] ]] = None ,
43
44
) -> None :
44
45
self .value = value
45
46
self .nvalue = nvalue
@@ -54,7 +55,7 @@ def new_leaf_node(
54
55
nvalue : int = 0 ,
55
56
idx_data_points : Optional [npt .NDArray [np .int_ ]] = None ,
56
57
idx_split_variable : int = - 1 ,
57
- linear_params : Optional [List [float ]] = None ,
58
+ linear_params : Optional [List [npt . NDArray [ np . float_ ] ]] = None ,
58
59
) -> "Node" :
59
60
return cls (
60
61
value = value ,
@@ -120,7 +121,7 @@ def __init__(
120
121
self ,
121
122
tree_structure : Dict [int , Node ],
122
123
output : npt .NDArray [np .float_ ],
123
- split_rules : List [object ],
124
+ split_rules : List [SplitRule ],
124
125
idx_leaf_nodes : Optional [List [int ]] = None ,
125
126
) -> None :
126
127
self .tree_structure = tree_structure
@@ -135,7 +136,7 @@ def new_tree(
135
136
idx_data_points : Optional [npt .NDArray [np .int_ ]],
136
137
num_observations : int ,
137
138
shape : int ,
138
- split_rules : List [object ],
139
+ split_rules : List [SplitRule ],
139
140
) -> "Tree" :
140
141
return cls (
141
142
tree_structure = {
@@ -258,7 +259,7 @@ def _traverse_tree(
258
259
self ,
259
260
X : npt .NDArray [np .float_ ],
260
261
excluded : Optional [List [int ]] = None ,
261
- shape : int = 1 ,
262
+ shape : Union [ int , Tuple [ int ,...]] = 1 ,
262
263
) -> npt .NDArray [np .float_ ]:
263
264
"""
264
265
Traverse the tree starting from the root node given an (un)observed point.
0 commit comments