|
17 | 17 | from multiprocessing import Manager
|
18 | 18 |
|
19 | 19 | import numpy as np
|
| 20 | +import numpy.typing as npt |
20 | 21 | import pytensor.tensor as pt
|
21 | 22 | from pandas import DataFrame, Series
|
22 | 23 | from pymc.distributions.distribution import Distribution, _moment
|
23 | 24 | from pymc.logprob.abstract import _logprob
|
24 | 25 | from pytensor.tensor.random.op import RandomVariable
|
| 26 | +from typing import List, Optional, Tuple |
| 27 | + |
| 28 | + |
| 29 | +from .utils import _sample_posterior, TensorLike |
| 30 | + |
25 | 31 |
|
26 |
| -from .utils import _sample_posterior |
27 | 32 |
|
28 | 33 | __all__ = ["BART"]
|
29 | 34 |
|
30 | 35 |
|
31 | 36 | class BARTRV(RandomVariable):
|
32 | 37 | """Base class for BART."""
|
33 | 38 |
|
34 |
| - name = "BART" |
| 39 | + name: str = "BART" |
35 | 40 | ndim_supp = 1
|
36 |
| - ndims_params = [2, 1, 0, 0, 1] |
37 |
| - dtype = "floatX" |
38 |
| - _print_name = ("BART", "\\operatorname{BART}") |
| 41 | + ndims_params: List[int] = [2, 1, 0, 0, 1] |
| 42 | + dtype: str = "floatX" |
| 43 | + _print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}") |
39 | 44 | all_trees = None
|
40 | 45 |
|
41 | 46 | def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
|
@@ -63,29 +68,29 @@ class BART(Distribution):
|
63 | 68 |
|
64 | 69 | Parameters
|
65 | 70 | ----------
|
66 |
| - X : array-like |
| 71 | + X : TensorLike |
67 | 72 | The covariate matrix.
|
68 |
| - Y : array-like |
| 73 | + Y : TensorLike |
69 | 74 | The response vector.
|
70 | 75 | m : int
|
71 | 76 | Number of trees
|
72 | 77 | alpha : float
|
73 | 78 | Control the prior probability over the depth of the trees. Even though it can take values in
|
74 | 79 | the interval (0, 1), it is recommended to be in the interval (0, 0.5].
|
75 |
| - split_prior : array-like |
| 80 | + split_prior : Optional[List[float]], default None. |
76 | 81 | Each element of split_prior should be in the [0, 1] interval and the elements should sum to
|
77 | 82 | 1. Otherwise they will be normalized.
|
78 | 83 | Defaults to 0, i.e. all covariates have the same prior probability to be selected.
|
79 | 84 | """
|
80 | 85 |
|
81 | 86 | def __new__(
|
82 | 87 | cls,
|
83 |
| - name, |
84 |
| - X, |
85 |
| - Y, |
86 |
| - m=50, |
87 |
| - alpha=0.25, |
88 |
| - split_prior=None, |
| 88 | + name: str, |
| 89 | + X: TensorLike, |
| 90 | + Y: TensorLike, |
| 91 | + m: int = 50, |
| 92 | + alpha: float = 0.25, |
| 93 | + split_prior: Optional[List[float]] = None, |
89 | 94 | **kwargs,
|
90 | 95 | ):
|
91 | 96 | manager = Manager()
|
@@ -146,7 +151,9 @@ def get_moment(cls, rv, size, *rv_inputs):
|
146 | 151 | return mean
|
147 | 152 |
|
148 | 153 |
|
149 |
| -def preprocess_xy(X, Y): |
| 154 | +def preprocess_xy( |
| 155 | + X: TensorLike, Y: TensorLike |
| 156 | +) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_]]: |
150 | 157 | if isinstance(Y, (Series, DataFrame)):
|
151 | 158 | Y = Y.to_numpy()
|
152 | 159 | if isinstance(X, (Series, DataFrame)):
|
|
0 commit comments