Skip to content

Commit be9dbf0

Browse files
Juan Orduzaloctavodia
Juan Orduz
andauthored
mypy init (#78)
* mypy init * undo change * fix tree.py types * small code improvements * fix batch type * add some more types * Update pymc_bart/pgbart.py Co-authored-by: Osvaldo A Martin <aloctavodia@gmail.com> * fix ident * more hints * add types utils and bart * fix lint * fix some hints * fix sort * remove outout None type * fix predict output type * fix module path to make it compatible with pymc5.2.0 * type hints to plots * unpin mypy specifict version * Update pymc_bart/utils.py Co-authored-by: Osvaldo A Martin <aloctavodia@gmail.com> * Update pymc_bart/pgbart.py Co-authored-by: Osvaldo A Martin <aloctavodia@gmail.com> * Update pymc_bart/pgbart.py Co-authored-by: Osvaldo A Martin <aloctavodia@gmail.com> * Update pymc_bart/tree.py Co-authored-by: Osvaldo A Martin <aloctavodia@gmail.com> * Update pymc_bart/tree.py Co-authored-by: Osvaldo A Martin <aloctavodia@gmail.com> * make idx_leaf_nodes list instead of numpy array * fix plot figsize type * fix exclude type --------- Co-authored-by: Osvaldo A Martin <aloctavodia@gmail.com>
1 parent 6e82299 commit be9dbf0

File tree

12 files changed

+392
-284
lines changed

12 files changed

+392
-284
lines changed

.github/workflows/test.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ jobs:
4343
echo "Success!"
4444
echo "Checking code style with pylint..."
4545
python -m pylint pymc_bart/
46+
- name: Run Mypy
47+
shell: bash -l {0}
48+
run: |
49+
python -m mypy pymc_bart
4650
- name: Run tests
4751
shell: bash -l {0}
4852
run: |

.pre-commit-config.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ repos:
1313
language: system
1414
types: [python]
1515
files: ^pymc_bart/
16+
- id: mypy
17+
name: mypy
18+
entry: mypy
19+
language: system
20+
types: [python]
21+
files: ^pymc_bart/
1622
- repo: https://github.com/pre-commit/pre-commit-hooks
1723
rev: v3.2.0
1824
hooks:

mypy.ini

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
[mypy]
2+
files = pymc_bart/*.py
3+
plugins = numpy.typing.mypy_plugin
4+
5+
[mypy-matplotlib.*]
6+
ignore_missing_imports = True
7+
8+
[mypy-numba.*]
9+
ignore_missing_imports = True
10+
11+
[mypy-pymc.*]
12+
ignore_missing_imports = True
13+
14+
[mypy-scipy.*]
15+
ignore_missing_imports = True

pymc_bart/bart.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,29 +15,31 @@
1515
# limitations under the License.
1616

1717
from multiprocessing import Manager
18+
from typing import List, Optional, Tuple
19+
1820
import numpy as np
21+
import numpy.typing as npt
22+
import pytensor.tensor as pt
1923
from pandas import DataFrame, Series
20-
2124
from pymc.distributions.distribution import Distribution, _moment
2225
from pymc.logprob.abstract import _logprob
23-
import pytensor.tensor as pt
2426
from pytensor.tensor.random.op import RandomVariable
2527

26-
27-
from .utils import _sample_posterior
28+
from .tree import Tree
29+
from .utils import TensorLike, _sample_posterior
2830

2931
__all__ = ["BART"]
3032

3133

3234
class BARTRV(RandomVariable):
3335
"""Base class for BART."""
3436

35-
name = "BART"
37+
name: str = "BART"
3638
ndim_supp = 1
37-
ndims_params = [2, 1, 0, 0, 1]
38-
dtype = "floatX"
39-
_print_name = ("BART", "\\operatorname{BART}")
40-
all_trees = None
39+
ndims_params: List[int] = [2, 1, 0, 0, 1]
40+
dtype: str = "floatX"
41+
_print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}")
42+
all_trees = List[List[Tree]]
4143

4244
def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
4345
return dist_params[0].shape[:1]
@@ -64,29 +66,29 @@ class BART(Distribution):
6466
6567
Parameters
6668
----------
67-
X : array-like
69+
X : TensorLike
6870
The covariate matrix.
69-
Y : array-like
71+
Y : TensorLike
7072
The response vector.
7173
m : int
7274
Number of trees
7375
alpha : float
7476
Control the prior probability over the depth of the trees. Even when it can takes values in
7577
the interval (0, 1), it is recommended to be in the interval (0, 0.5].
76-
split_prior : array-like
78+
split_prior : Optional[List[float]], default None.
7779
Each element of split_prior should be in the [0, 1] interval and the elements should sum to
7880
1. Otherwise they will be normalized.
7981
Defaults to 0, i.e. all covariates have the same prior probability to be selected.
8082
"""
8183

8284
def __new__(
8385
cls,
84-
name,
85-
X,
86-
Y,
87-
m=50,
88-
alpha=0.25,
89-
split_prior=None,
86+
name: str,
87+
X: TensorLike,
88+
Y: TensorLike,
89+
m: int = 50,
90+
alpha: float = 0.25,
91+
split_prior: Optional[List[float]] = None,
9092
**kwargs,
9193
):
9294
manager = Manager()
@@ -147,7 +149,9 @@ def get_moment(cls, rv, size, *rv_inputs):
147149
return mean
148150

149151

150-
def preprocess_xy(X, Y):
152+
def preprocess_xy(
153+
X: TensorLike, Y: TensorLike
154+
) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_]]:
151155
if isinstance(Y, (Series, DataFrame)):
152156
Y = Y.to_numpy()
153157
if isinstance(X, (Series, DataFrame)):

0 commit comments

Comments
 (0)