Skip to content

Commit 846b0b8

Browse files
author
juanitorduz
committed
add types utils and bart
1 parent 38436a9 commit 846b0b8

File tree

2 files changed

+36
-16
lines changed

2 files changed

+36
-16
lines changed

pymc_bart/bart.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,30 @@
1717
from multiprocessing import Manager
1818

1919
import numpy as np
20+
import numpy.typing as npt
2021
import pytensor.tensor as pt
2122
from pandas import DataFrame, Series
2223
from pymc.distributions.distribution import Distribution, _moment
2324
from pymc.logprob.abstract import _logprob
2425
from pytensor.tensor.random.op import RandomVariable
26+
from typing import List, Optional, Tuple
27+
28+
29+
from .utils import _sample_posterior, TensorLike
30+
2531

26-
from .utils import _sample_posterior
2732

2833
__all__ = ["BART"]
2934

3035

3136
class BARTRV(RandomVariable):
3237
"""Base class for BART."""
3338

34-
name = "BART"
39+
name: str = "BART"
3540
ndim_supp = 1
36-
ndims_params = [2, 1, 0, 0, 1]
37-
dtype = "floatX"
38-
_print_name = ("BART", "\\operatorname{BART}")
41+
ndims_params: List[int] = [2, 1, 0, 0, 1]
42+
dtype: str = "floatX"
43+
_print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}")
3944
all_trees = None
4045

4146
def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
@@ -63,29 +68,29 @@ class BART(Distribution):
6368
6469
Parameters
6570
----------
66-
X : array-like
71+
X : TensorLike
6772
The covariate matrix.
68-
Y : array-like
73+
Y : TensorLike
6974
The response vector.
7075
m : int
7176
Number of trees
7277
alpha : float
7378
Control the prior probability over the depth of the trees. Even when it can takes values in
7479
the interval (0, 1), it is recommended to be in the interval (0, 0.5].
75-
split_prior : array-like
80+
split_prior : Optional[List[float]], default None.
7681
Each element of split_prior should be in the [0, 1] interval and the elements should sum to
7782
1. Otherwise they will be normalized.
7883
Defaults to 0, i.e. all covariates have the same prior probability to be selected.
7984
"""
8085

8186
def __new__(
8287
cls,
83-
name,
84-
X,
85-
Y,
86-
m=50,
87-
alpha=0.25,
88-
split_prior=None,
88+
name: str,
89+
X: TensorLike,
90+
Y: TensorLike,
91+
m: int = 50,
92+
alpha: float = 0.25,
93+
split_prior: Optional[List[float]] = None,
8994
**kwargs,
9095
):
9196
manager = Manager()
@@ -146,7 +151,9 @@ def get_moment(cls, rv, size, *rv_inputs):
146151
return mean
147152

148153

149-
def preprocess_xy(X, Y):
154+
def preprocess_xy(
155+
X: TensorLike, Y: TensorLike
156+
) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_]]:
150157
if isinstance(Y, (Series, DataFrame)):
151158
Y = Y.to_numpy()
152159
if isinstance(X, (Series, DataFrame)):

pymc_bart/utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,29 @@
22

33
import warnings
44

5+
import pytensor.tensor as pt
56
import arviz as az
67
import matplotlib.pyplot as plt
78
import numpy as np
9+
import numpy.typing as npt
810
from pytensor.tensor.var import Variable
911
from scipy.interpolate import griddata
1012
from scipy.signal import savgol_filter
1113
from scipy.stats import norm, pearsonr
14+
from typing import List, Optional, Tuple, Union
15+
from .tree import Tree
1216

1317

14-
def _sample_posterior(all_trees, X, rng, size=None, excluded=None):
18+
TensorLike = Union[npt.NDArray[np.float_], pt.TensorVariable]
19+
20+
21+
def _sample_posterior(
22+
all_trees: List[List[Tree]],
23+
X: TensorLike,
24+
rng: np.random.Generator,
25+
size=Optional[Union[int, Tuple[int, ...]]],
26+
excluded: Optional[List[int]] = None,
27+
) -> npt.NDArray[np.float_]:
1528
"""
1629
Generate samples from the BART-posterior.
1730

0 commit comments

Comments
 (0)