Skip to content

mypy init #78

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Mar 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ jobs:
echo "Success!"
echo "Checking code style with pylint..."
python -m pylint pymc_bart/
- name: Run Mypy
shell: bash -l {0}
run: |
python -m mypy pymc_bart
- name: Run tests
shell: bash -l {0}
run: |
Expand Down
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ repos:
language: system
types: [python]
files: ^pymc_bart/
- id: mypy
name: mypy
entry: mypy
language: system
types: [python]
files: ^pymc_bart/
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
Expand Down
15 changes: 15 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[mypy]
files = pymc_bart/*.py
plugins = numpy.typing.mypy_plugin

[mypy-matplotlib.*]
ignore_missing_imports = True

[mypy-numba.*]
ignore_missing_imports = True

[mypy-pymc.*]
ignore_missing_imports = True

[mypy-scipy.*]
ignore_missing_imports = True
42 changes: 23 additions & 19 deletions pymc_bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,31 @@
# limitations under the License.

from multiprocessing import Manager
from typing import List, Optional, Tuple

import numpy as np
import numpy.typing as npt
import pytensor.tensor as pt
from pandas import DataFrame, Series

from pymc.distributions.distribution import Distribution, _moment
from pymc.logprob.abstract import _logprob
import pytensor.tensor as pt
from pytensor.tensor.random.op import RandomVariable


from .utils import _sample_posterior
from .tree import Tree
from .utils import TensorLike, _sample_posterior

__all__ = ["BART"]


class BARTRV(RandomVariable):
"""Base class for BART."""

name = "BART"
name: str = "BART"
ndim_supp = 1
ndims_params = [2, 1, 0, 0, 1]
dtype = "floatX"
_print_name = ("BART", "\\operatorname{BART}")
all_trees = None
ndims_params: List[int] = [2, 1, 0, 0, 1]
dtype: str = "floatX"
_print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}")
all_trees = List[List[Tree]]

def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
return dist_params[0].shape[:1]
Expand All @@ -64,29 +66,29 @@ class BART(Distribution):

Parameters
----------
X : array-like
X : TensorLike
The covariate matrix.
Y : array-like
Y : TensorLike
The response vector.
m : int
Number of trees
alpha : float
Control the prior probability over the depth of the trees. Even when it can takes values in
the interval (0, 1), it is recommended to be in the interval (0, 0.5].
split_prior : array-like
split_prior : Optional[List[float]], default None.
Each element of split_prior should be in the [0, 1] interval and the elements should sum to
1. Otherwise they will be normalized.
Defaults to 0, i.e. all covariates have the same prior probability to be selected.
"""

def __new__(
cls,
name,
X,
Y,
m=50,
alpha=0.25,
split_prior=None,
name: str,
X: TensorLike,
Y: TensorLike,
m: int = 50,
alpha: float = 0.25,
split_prior: Optional[List[float]] = None,
**kwargs,
):
manager = Manager()
Expand Down Expand Up @@ -147,7 +149,9 @@ def get_moment(cls, rv, size, *rv_inputs):
return mean


def preprocess_xy(X, Y):
def preprocess_xy(
X: TensorLike, Y: TensorLike
) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_]]:
if isinstance(Y, (Series, DataFrame)):
Y = Y.to_numpy()
if isinstance(X, (Series, DataFrame)):
Expand Down
Loading