diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3b5a5df..db9f630 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -43,6 +43,10 @@ jobs: echo "Success!" echo "Checking code style with pylint..." python -m pylint pymc_bart/ + - name: Run Mypy + shell: bash -l {0} + run: | + python -m mypy pymc_bart - name: Run tests shell: bash -l {0} run: | diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index da409e7..5432a77 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,6 +13,12 @@ repos: language: system types: [python] files: ^pymc_bart/ + - id: mypy + name: mypy + entry: mypy + language: system + types: [python] + files: ^pymc_bart/ - repo: https://github.com/pre-commit/pre-commit-hooks rev: v3.2.0 hooks: diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..56088d7 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,15 @@ +[mypy] +files = pymc_bart/*.py +plugins = numpy.typing.mypy_plugin + +[mypy-matplotlib.*] +ignore_missing_imports = True + +[mypy-numba.*] +ignore_missing_imports = True + +[mypy-pymc.*] +ignore_missing_imports = True + +[mypy-scipy.*] +ignore_missing_imports = True diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index dad53d6..b7f688f 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -15,16 +15,18 @@ # limitations under the License. from multiprocessing import Manager +from typing import List, Optional, Tuple + import numpy as np +import numpy.typing as npt +import pytensor.tensor as pt from pandas import DataFrame, Series - from pymc.distributions.distribution import Distribution, _moment from pymc.logprob.abstract import _logprob -import pytensor.tensor as pt from pytensor.tensor.random.op import RandomVariable - -from .utils import _sample_posterior +from .tree import Tree +from .utils import TensorLike, _sample_posterior __all__ = ["BART"] @@ -32,12 +34,12 @@ class BARTRV(RandomVariable): """Base class for BART.""" - name = "BART" + name: str = "BART" ndim_supp = 1 - ndims_params = [2, 1, 0, 0, 1] - dtype = "floatX" - _print_name = ("BART", "\\operatorname{BART}") - all_trees = None + ndims_params: List[int] = [2, 1, 0, 0, 1] + dtype: str = "floatX" + _print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}") + all_trees: List[List[Tree]] = [] def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): return dist_params[0].shape[:1] @@ -64,16 +66,16 @@ class BART(Distribution): Parameters ---------- - X : array-like + X : TensorLike The covariate matrix. - Y : array-like + Y : TensorLike The response vector. m : int Number of trees alpha : float Control the prior probability over the depth of the trees. Even when it can take values in the interval (0, 1), it is recommended to be in the interval (0, 0.5]. - split_prior : array-like + split_prior : Optional[List[float]], default None. Each element of split_prior should be in the [0, 1] interval and the elements should sum to 1. Otherwise they will be normalized. Defaults to None, i.e. all covariates have the same prior probability of being selected.
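As a quick illustration of the annotation style this changeset standardizes on (npt.NDArray plus typing generics), the following self-contained snippet — illustrative names only, not part of the diff — is the kind of code that `python -m mypy pymc_bart` with the mypy.ini above is meant to accept:

    from typing import List

    import numpy as np
    import numpy.typing as npt


    def normalize_log_weights(log_weights: List[float]) -> npt.NDArray[np.float_]:
        """Numerically stable softmax over particle log-weights (illustrative)."""
        log_w = np.asarray(log_weights, dtype=np.float_)
        wei = np.exp(log_w - log_w.max()) + 1e-12
        return wei / wei.sum()

Note the return annotation: the function returns an array, so annotating it `-> float` would be exactly the kind of mismatch this setup is meant to catch.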
@@ -81,12 +83,12 @@ class BART(Distribution): def __new__( cls, - name, - X, - Y, - m=50, - alpha=0.25, - split_prior=None, + name: str, + X: TensorLike, + Y: TensorLike, + m: int = 50, + alpha: float = 0.25, + split_prior: Optional[List[float]] = None, **kwargs, ): manager = Manager() @@ -147,7 +149,9 @@ def get_moment(cls, rv, size, *rv_inputs): return mean -def preprocess_xy(X, Y): +def preprocess_xy( + X: TensorLike, Y: TensorLike +) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_]]: if isinstance(Y, (Series, DataFrame)): Y = Y.to_numpy() if isinstance(X, (Series, DataFrame)): diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 6755836..c093d12 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -12,21 +12,76 @@ # See the License for the specific language governing permissions and # limitations under the License. -from numba import njit - +from typing import List, Optional, Tuple, Union +import numpy.typing as npt import numpy as np - -from pytensor import function as pytensor_function -from pytensor import config -from pytensor.tensor.var import Variable - -from pymc.model import modelcontext +from numba import njit +from pymc.model import Model, modelcontext +from pymc.pytensorf import inputvars, join_nonshared_inputs, make_shared_replacements from pymc.step_methods.arraystep import ArrayStepShared from pymc.step_methods.compound import Competence -from pymc.pytensorf import inputvars, join_nonshared_inputs, make_shared_replacements +from pytensor import config +from pytensor import function as pytensor_function +from pytensor.tensor.var import Variable from pymc_bart.bart import BARTRV -from pymc_bart.tree import Tree, Node, get_depth +from pymc_bart.tree import Node, Tree, get_depth + + +class ParticleTree: + """Particle tree.""" + + __slots__ = "tree", "expansion_nodes", "log_weight", "kfactor" + + def __init__(self, tree: Tree, kfactor: float = 0.75): + self.tree: Tree = tree.copy() + self.expansion_nodes: List[int] = [0] + self.log_weight: float = 0 + self.kfactor: float = kfactor + + def copy(self) -> "ParticleTree": + p = ParticleTree(self.tree) + p.expansion_nodes = self.expansion_nodes.copy() + p.kfactor = self.kfactor + return p + + def sample_tree( + self, + ssv, + available_predictors, + prior_prob_leaf_node, + X, + missing_data, + sum_trees, + m, + normal, + shape, + ) -> bool: + tree_grew = False + if self.expansion_nodes: + index_leaf_node = self.expansion_nodes.pop(0) + # Probability that this node will remain a leaf node + prob_leaf = prior_prob_leaf_node[get_depth(index_leaf_node)] + + if prob_leaf < np.random.random(): + idx_new_nodes = grow_tree( + self.tree, + index_leaf_node, + ssv, + available_predictors, + X, + missing_data, + sum_trees, + m, + normal, + self.kfactor, + shape, + ) + if idx_new_nodes is not None: + self.expansion_nodes.extend(idx_new_nodes) + tree_grew = True + + return tree_grew class PGBART(ArrayStepShared): @@ -55,9 +110,9 @@ class PGBART(ArrayStepShared): def __init__( self, vars=None, # pylint: disable=redefined-builtin - num_particles=20, - batch="auto", - model=None, + num_particles: int = 20, + batch: Tuple[float, float] = (0.1, 0.1), + model: Optional[Model] = None, ): model = modelcontext(model) initial_values = model.initial_point() @@ -77,10 +132,7 @@ def __init__( self.missing_data = np.any(np.isnan(self.X)) self.m = self.bart.m shape = initial_values[value_bart.name].shape - if len(shape) == 1: - self.shape = 1 - else: - self.shape = shape[0] + self.shape = 1 if len(shape) == 1 else shape[0] if 
self.bart.split_prior: self.alpha_vec = self.bart.split_prior @@ -116,20 +168,15 @@ def __init__( self.tune = True - if batch == "auto": - batch = max(1, int(self.m * 0.1)) - self.batch = (batch, batch) - else: - if isinstance(batch, (tuple, list)): - self.batch = batch - else: - self.batch = (batch, batch) + batch_0 = max(1, int(self.m * batch[0])) + batch_1 = max(1, int(self.m * batch[1])) + self.batch = (batch_0, batch_1) self.num_particles = num_particles self.indices = list(range(1, num_particles)) shared = make_shared_replacements(initial_values, vars, model) self.likelihood_logp = logp(initial_values, [model.datalogp], vars, shared) - self.all_particles = list(ParticleTree(self.a_tree) for _ in range(self.m)) + self.all_particles = [ParticleTree(self.a_tree) for _ in range(self.m)] self.all_trees = np.array([p.tree for p in self.all_particles]) self.lower = 0 self.iter = 0 @@ -154,7 +201,7 @@ def astep(self, _): # Sample each particle (try to grow each tree), except for the first one stop_growing = True for p in particles[1:]: - tree_grew = p.sample_tree( + if p.sample_tree( self.ssv, self.available_predictors, self.prior_prob_leaf_node, @@ -164,8 +211,7 @@ def astep(self, _): self.m, self.normal, self.shape, - ) - if tree_grew: + ): self.update_weight(p) if p.expansion_nodes: stop_growing = False @@ -205,7 +251,7 @@ def astep(self, _): stats = {"variable_inclusion": variable_inclusion, "tune": self.tune} return self.sum_trees, [stats] - def normalize(self, particles): + def normalize(self, particles: List[ParticleTree]) -> npt.NDArray[np.float_]: """ Use softmax to get normalized_weights. """ @@ -215,15 +261,17 @@ def normalize(self, particles): wei = np.exp(log_w_) + 1e-12 return wei / wei.sum() - def resample(self, particles, normalized_weights): + def resample( + self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float_] + ) -> List[ParticleTree]: """ Use systematic resample for all but the first particle. Ensure particles are copied only if needed. """ new_indices = self.systematic(normalized_weights) + 1 - seen = [] - new_particles = [] + seen: List[int] = [] + new_particles: List[ParticleTree] = [] for idx in new_indices: if idx in seen: new_particles.append(particles[idx].copy()) @@ -235,7 +283,9 @@ def resample(self, particles, normalized_weights): return particles - def get_particle_tree(self, particles, normalized_weights): + def get_particle_tree( + self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float_] + ) -> Tuple[ParticleTree, Tree]: """ Sample a new particle and associated tree. """ @@ -246,7 +296,7 @@ def get_particle_tree(self, particles, normalized_weights): return new_particle, new_particle.tree - def systematic(self, normalized_weights): + def systematic(self, normalized_weights: npt.NDArray[np.float_]) -> npt.NDArray[np.int_]: """ Systematic resampling.
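For reference, the systematic/inverse-CDF pattern that `systematic` and the jitted `inverse_cdf` further down implement can be sketched in a few self-contained lines; the function name and `rng` argument here are illustrative, and `np.searchsorted` stands in for the numba loop:

    import numpy as np


    def systematic_resample(
        normalized_weights: np.ndarray, rng: np.random.Generator
    ) -> np.ndarray:
        # A single uniform draw is stratified into n evenly spaced points in [0, 1).
        n = len(normalized_weights)
        points = (rng.random() + np.arange(n)) / n
        # Inverse CDF: each point selects the particle whose cumulative-weight
        # interval contains it, so heavier particles are selected more often.
        return np.searchsorted(np.cumsum(normalized_weights), points)


    print(systematic_resample(np.array([0.1, 0.2, 0.3, 0.4]), np.random.default_rng(0)))
    # [1 2 3 3] for this seed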
@@ -258,21 +308,20 @@ def systematic(self, normalized_weights): single_uniform = (self.uniform.rvs() + np.arange(lnw)) / lnw return inverse_cdf(single_uniform, normalized_weights) - def init_particles(self, tree_id: int) -> np.ndarray: + def init_particles(self, tree_id: int) -> List[ParticleTree]: """Initialize particles.""" - p0 = self.all_particles[tree_id] + p0: ParticleTree = self.all_particles[tree_id] # The old tree does not grow so we update the weight only once self.update_weight(p0) - particles = [p0] - - for _ in self.indices: - particles.append( - ParticleTree(self.a_tree, self.uniform_kf.rvs() if self.tune else p0.kfactor) - ) + particles: List[ParticleTree] = [p0] + particles.extend( + ParticleTree(self.a_tree, self.uniform_kf.rvs() if self.tune else p0.kfactor) + for _ in self.indices + ) return particles - def update_weight(self, particle): + def update_weight(self, particle: ParticleTree) -> None: """ Update the weight of a particle. """ @@ -290,64 +339,8 @@ def competence(var, has_grad): return Competence.INCOMPATIBLE -class ParticleTree: - """Particle tree.""" - - __slots__ = "tree", "expansion_nodes", "log_weight", "kfactor" - - def __init__(self, tree, kfactor=0.75): - self.tree = tree.copy() - self.expansion_nodes = [0] - self.log_weight = 0 - self.kfactor = kfactor - - def copy(self): - p = ParticleTree(self.tree) - p.expansion_nodes = self.expansion_nodes.copy() - p.kfactor = self.kfactor - return p - - def sample_tree( - self, - ssv, - available_predictors, - prior_prob_leaf_node, - X, - missing_data, - sum_trees, - m, - normal, - shape, - ): - tree_grew = False - if self.expansion_nodes: - index_leaf_node = self.expansion_nodes.pop(0) - # Probability that this node will remain a leaf node - prob_leaf = prior_prob_leaf_node[get_depth(index_leaf_node)] - - if prob_leaf < np.random.random(): - idx_new_nodes = grow_tree( - self.tree, - index_leaf_node, - ssv, - available_predictors, - X, - missing_data, - sum_trees, - m, - normal, - self.kfactor, - shape, - ) - if idx_new_nodes is not None: - self.expansion_nodes.extend(idx_new_nodes) - tree_grew = True - - return tree_grew - - class SampleSplittingVariable: - def __init__(self, alpha_vec): + def __init__(self, alpha_vec: npt.NDArray[np.float_]) -> None: """ Sample splitting variables proportional to `alpha_vec`. @@ -356,15 +349,15 @@ def __init__(self, alpha_vec): """ self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum()))) - def rvs(self): - rnd = np.random.random() + def rvs(self) -> Union[int, Tuple[int, float]]: + rnd: float = np.random.random() for i, val in self.enu: if rnd <= val: return i return self.enu[-1] -def compute_prior_probability(alpha): +def compute_prior_probability(alpha) -> List[float]: """ Calculate the probability of the node being a leaf node (1 - p(being split node)). @@ -383,7 +376,7 @@ def compute_prior_probability(alpha): .. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART. 
arXiv, `link `__ """ - prior_leaf_prob = [0] + prior_leaf_prob: List[float] = [0] depth = 1 while prior_leaf_prob[-1] < 1: prior_leaf_prob.append(1 - alpha**depth) @@ -422,7 +415,7 @@ def grow_tree( current_node.get_idx_right_child(), ) - new_nodes = [] + new_nodes = np.empty(2, dtype=object) for idx in range(2): idx_data_point = new_idx_data_points[idx] node_value = draw_leaf_value( @@ -437,7 +430,7 @@ value=node_value, idx_data_points=idx_data_point, ) - new_nodes.append(new_node) + new_nodes[idx] = new_node tree.grow_leaf_node(current_node, selected_predictor, split_value, index_leaf_node) tree.set_node(new_nodes[0].index, new_nodes[0]) @@ -560,17 +553,19 @@ def update(self): @njit -def inverse_cdf(single_uniform, normalized_weights): +def inverse_cdf( + single_uniform: npt.NDArray[np.float_], normalized_weights: npt.NDArray[np.float_] +) -> npt.NDArray[np.int_]: """ Inverse CDF algorithm for a finite distribution. Parameters ---------- - single_uniform: ndarray - ordered points in [0,1] + single_uniform: npt.NDArray[np.float_] + Ordered points in [0,1] - normalized_weights: ndarray - normalized weights + normalized_weights: npt.NDArray[np.float_] + Normalized weights Returns ------- diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 4f98709..05f4301 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -13,11 +13,65 @@ # limitations under the License. import math - from functools import lru_cache +from typing import Dict, Generator, List, Optional -from pytensor import config import numpy as np +import numpy.typing as npt +from pytensor import config + + +class Node: + """Node of a binary tree. + + Attributes + ---------- + index : int + value : float + idx_data_points : Optional[npt.NDArray[np.int_]] + idx_split_variable : int + """ + + __slots__ = "index", "value", "idx_split_variable", "idx_data_points" + + def __init__( + self, + index: int, + value: float = -1.0, + idx_data_points: Optional[npt.NDArray[np.int_]] = None, + idx_split_variable: int = -1, + ) -> None: + self.index = index + self.value = value + self.idx_data_points = idx_data_points + self.idx_split_variable = idx_split_variable + + @classmethod + def new_leaf_node( + cls, index: int, value: float, idx_data_points: Optional[npt.NDArray[np.int_]] + ) -> "Node": + return cls(index, value=value, idx_data_points=idx_data_points) + + @classmethod + def new_split_node(cls, index: int, split_value: float, idx_split_variable: int) -> "Node": + return cls(index=index, value=split_value, idx_split_variable=idx_split_variable) + + def get_idx_left_child(self) -> int: + return self.index * 2 + 1 + + def get_idx_right_child(self) -> int: + return self.index * 2 + 2 + + def is_split_node(self) -> bool: + return self.idx_split_variable >= 0 + + def is_leaf_node(self) -> bool: + return not self.is_split_node() + + +@lru_cache +def get_depth(index: int) -> int: + return math.floor(math.log2(index + 1)) class Tree: @@ -28,149 +82,178 @@ class Tree: Attributes ---------- - tree_structure : dict + tree_structure : Dict[int, Node] A dictionary that represents the nodes stored in breadth-first order, based on the array method for storing binary trees (https://en.wikipedia.org/wiki/Binary_tree#Arrays). The dictionary's keys are integers that represent the nodes' positions. The dictionary's values are objects of type Node that represent the split and leaf nodes of the tree itself. - idx_leaf_nodes : list - List with the index of the leaf nodes of the tree.
- output: array + output : npt.NDArray[np.float_] Array of shape (number of observations, shape). + idx_leaf_nodes : Optional[List[int]], by default None. + List with the index of the leaf nodes of the tree. Parameters ---------- tree_structure : Dictionary of nodes - idx_leaf_nodes : List with the index of the leaf nodes of the tree. output : Array of shape (number of observations, shape). + idx_leaf_nodes : List with the index of the leaf nodes of the tree. __slots__ = ( "tree_structure", - "idx_leaf_nodes", "output", + "idx_leaf_nodes", ) - def __init__(self, tree_structure, idx_leaf_nodes, output): + def __init__( + self, + tree_structure: Dict[int, Node], + output: npt.NDArray[np.float_], + idx_leaf_nodes: Optional[List[int]] = None, + ) -> None: self.tree_structure = tree_structure self.idx_leaf_nodes = idx_leaf_nodes self.output = output @classmethod - def new_tree(cls, leaf_node_value, idx_data_points, num_observations, shape): + def new_tree( + cls, + leaf_node_value: float, + idx_data_points: Optional[npt.NDArray[np.int_]], + num_observations: int, + shape: int, + ) -> "Tree": return cls( tree_structure={ - 0: Node.new_leaf_node(0, value=leaf_node_value, idx_data_points=idx_data_points) + 0: Node.new_leaf_node( + index=0, value=leaf_node_value, idx_data_points=idx_data_points + ) }, idx_leaf_nodes=[0], output=np.zeros((num_observations, shape)).astype(config.floatX).squeeze(), ) - def __getitem__(self, index): + def __getitem__(self, index: int) -> Node: return self.get_node(index) - def __setitem__(self, index, node): + def __setitem__(self, index: int, node: Node) -> None: self.set_node(index, node) - def copy(self): + def copy(self) -> "Tree": tree: Dict[int, Node] = { k: Node(v.index, v.value, v.idx_data_points, v.idx_split_variable) for k, v in self.tree_structure.items() } - return Tree(tree, self.idx_leaf_nodes.copy(), self.output.copy()) + idx_leaf_nodes = self.idx_leaf_nodes.copy() if self.idx_leaf_nodes is not None else None + return Tree(tree_structure=tree, idx_leaf_nodes=idx_leaf_nodes, output=self.output) - def get_node(self, index) -> "Node": + def get_node(self, index: int) -> Node: return self.tree_structure[index] - def set_node(self, index, node): + def set_node(self, index: int, node: Node) -> None: self.tree_structure[index] = node - if node.is_leaf_node(): + if node.is_leaf_node() and self.idx_leaf_nodes is not None: self.idx_leaf_nodes.append(index) - def grow_leaf_node(self, current_node, selected_predictor, split_value, index_leaf_node): + def grow_leaf_node( + self, current_node: Node, selected_predictor: int, split_value: float, index_leaf_node: int + ) -> None: current_node.value = split_value current_node.idx_split_variable = selected_predictor current_node.idx_data_points = None - self.idx_leaf_nodes.remove(index_leaf_node) + if self.idx_leaf_nodes is not None: + self.idx_leaf_nodes.remove(index_leaf_node) - def trim(self): - tree = { + def trim(self) -> "Tree": + tree: Dict[int, Node] = { k: Node(v.index, v.value, None, v.idx_split_variable) for k, v in self.tree_structure.items() } - return Tree(tree, None, None) + return Tree(tree_structure=tree, idx_leaf_nodes=None, output=np.array([-1])) - def get_split_variables(self): + def get_split_variables(self) -> Generator[int, None, None]: for node in self.tree_structure.values(): if node.is_split_node(): yield node.idx_split_variable - def _predict(self): + def _predict(self) -> npt.NDArray[np.float_]: output = self.output - for node_index in self.idx_leaf_nodes: - leaf_node = 
self.get_node(node_index) - output[leaf_node.idx_data_points] = leaf_node.value + + if self.idx_leaf_nodes is not None: + for node_index in self.idx_leaf_nodes: + leaf_node = self.get_node(node_index) + output[leaf_node.idx_data_points] = leaf_node.value return output.T - def predict(self, x, excluded=None): + def predict( + self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = None + ) -> npt.NDArray[np.float_]: """ Predict output of tree for an (un)observed point x. Parameters ---------- - x : numpy array + x : npt.NDArray[np.float_] Unobserved point - excluded: list - Indexes of the variables to exclude when computing predictions + excluded: Optional[List[int]] + Indexes of the variables to exclude when computing predictions Returns ------- - float + npt.NDArray[np.float_] Value of the leaf node where the unobserved point lies. """ if excluded is None: excluded = [] return self._traverse_tree(x, 0, excluded) - def _traverse_tree(self, x, node_index, excluded): + def _traverse_tree( + self, + x: npt.NDArray[np.float_], + node_index: int, + excluded: Optional[List[int]] = None, + ) -> npt.NDArray[np.float_]: """ Traverse the tree starting from a particular node given an unobserved point. Parameters ---------- - x : np.ndarray + x : npt.NDArray[np.float_] + Unobserved point node_index : int + Index of the node to start the traversal from + excluded: Optional[List[int]] + Indexes of the variables to exclude when computing predictions Returns ------- - Leaf node value or mean of leaf node values + npt.NDArray[np.float_] + Leaf node value or mean of leaf node values """ - current_node = self.get_node(node_index) + current_node: Node = self.get_node(node_index) if current_node.is_leaf_node(): - return current_node.value - if current_node.idx_split_variable in excluded: - leaf_values = [] + return np.array(current_node.value) + + if excluded is not None and current_node.idx_split_variable in excluded: + leaf_values: List[float] = [] self._traverse_leaf_values(leaf_values, node_index) - return np.mean(leaf_values, 0) + return np.mean(leaf_values, axis=0) if x[current_node.idx_split_variable] <= current_node.value: next_node = current_node.get_idx_left_child() else: next_node = current_node.get_idx_right_child() - return self._traverse_tree(x, next_node, excluded) + return self._traverse_tree(x=x, node_index=next_node, excluded=excluded) - def _traverse_leaf_values(self, leaf_values, node_index): + def _traverse_leaf_values(self, leaf_values: List[float], node_index: int) -> None: """ Traverse the tree appending leaf values starting from a particular node.
Parameters ---------- + leaf_values : List[float] node_index : int - - Returns - ------- - List of leaf node values """ node = self.get_node(node_index) if node.is_leaf_node(): @@ -178,38 +261,3 @@ def _traverse_leaf_values(self, leaf_values, node_index): else: self._traverse_leaf_values(leaf_values, node.get_idx_left_child()) self._traverse_leaf_values(leaf_values, node.get_idx_right_child()) - - -class Node: - __slots__ = "index", "value", "idx_split_variable", "idx_data_points" - - def __init__(self, index: int, value=-1, idx_data_points=None, idx_split_variable=-1): - self.index = index - self.value = value - self.idx_data_points = idx_data_points - self.idx_split_variable = idx_split_variable - - @classmethod - def new_leaf_node(cls, index: int, value, idx_data_points) -> "Node": - return cls(index, value=value, idx_data_points=idx_data_points) - - @classmethod - def new_split_node(cls, index: int, split_value, idx_split_variable) -> "Node": - return cls(index, value=split_value, idx_split_variable=idx_split_variable) - - def get_idx_left_child(self) -> int: - return self.index * 2 + 1 - - def get_idx_right_child(self) -> int: - return self.index * 2 + 2 - - def is_split_node(self) -> bool: - return self.idx_split_variable >= 0 - - def is_leaf_node(self) -> bool: - return not self.is_split_node() - - -@lru_cache -def get_depth(index: int) -> int: - return math.floor(math.log2(index + 1)) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index a693955..45b9a75 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -1,18 +1,30 @@ """Utility function for variable selection and bart interpretability.""" import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import arviz as az import matplotlib.pyplot as plt import numpy as np - +import numpy.typing as npt +import pytensor.tensor as pt from pytensor.tensor.var import Variable from scipy.interpolate import griddata from scipy.signal import savgol_filter -from scipy.stats import pearsonr, norm +from scipy.stats import norm, pearsonr + +from .tree import Tree + +TensorLike = Union[npt.NDArray[np.float_], pt.TensorVariable] -def _sample_posterior(all_trees, X, rng, size=None, excluded=None): +def _sample_posterior( + all_trees: List[List[Tree]], + X: TensorLike, + rng: np.random.Generator, + size: Optional[Union[int, Tuple[int, ...]]] = None, + excluded: Optional[npt.NDArray[np.int_]] = None, +) -> npt.NDArray[np.float_]: """ Generate samples from the BART-posterior. @@ -20,13 +32,13 @@ def _sample_posterior(all_trees, X, rng, size=None, excluded=None): ---------- all_trees : list List of all trees sampled from a posterior - X : array-like + X : tensor-like A covariate matrix. Use the same used to fit BART for in-sample predictions or a new one for out-of-sample predictions. rng : NumPy RandomGenerator size : int or tuple Number of samples. 
- excluded : list + excluded : Optional[npt.NDArray[np.int_]] Indexes of the variables to exclude when computing predictions """ stacked_trees = all_trees @@ -34,12 +46,14 @@ def _sample_posterior(all_trees, X, rng, size=None, excluded=None): X = X.eval() if size is None: - size = () + size_iter: Union[List, Tuple] = () elif isinstance(size, int): - size = [size] + size_iter = [size] + else: + size_iter = size flatten_size = 1 - for s in size: + for s in size_iter: flatten_size *= s idx = rng.integers(0, len(stacked_trees), size=flatten_size) @@ -50,11 +64,17 @@ def _sample_posterior(all_trees, X, rng, size=None, excluded=None): for ind, p in enumerate(pred): for tree in stacked_trees[idx[ind]]: p += np.vstack([tree.predict(x, excluded) for x in X]) - pred.reshape((*size, shape, -1)) + pred.reshape((*size_iter, shape, -1)) return pred -def plot_convergence(idata, var_name=None, kind="ecdf", figsize=None, ax=None): +def plot_convergence( + idata: az.InferenceData, + var_name: Optional[str] = None, + kind: str = "ecdf", + figsize: Optional[Tuple[float, float]] = None, + ax=None, +) -> List[plt.Axes]: """ Plot convergence diagnostics. @@ -62,20 +82,20 @@ def plot_convergence(idata, var_name=None, kind="ecdf", figsize=None, ax=None): ---------- idata : InferenceData InferenceData object containing the posterior samples. - var_name : str + var_name : Optional[str] Name of the BART variable to plot. Defaults to None. kind : str Type of plot to display. Options are "ecdf" (default) and "kde". - figsize : tuple + figsize : Optional[Tuple[float, float]] Figure size. Defaults to None. ax : matplotlib axes Axes on which to plot. Defaults to None. Returns ------- - ax : matplotlib axes + axes : List[plt.Axes] """ - ess_threshold = idata.posterior.chain.size * 100 + ess_threshold = idata["posterior"]["chain"].size * 100 ess = np.atleast_2d(az.ess(idata, method="bulk", var_names=var_name)[var_name].values) rhat = np.atleast_2d(az.rhat(idata, var_names=var_name)[var_name].values) @@ -83,7 +103,7 @@ def plot_convergence(idata, var_name=None, kind="ecdf", figsize=None, ax=None): figsize = (10, 3) if kind == "ecdf": - kind_func = az.plot_ecdf + kind_func: Callable[..., Any] = az.plot_ecdf sharey = True elif kind == "kde": kind_func = az.plot_kde @@ -111,29 +131,28 @@ def plot_convergence(idata, var_name=None, kind="ecdf", figsize=None, ax=None): def plot_dependence( - bartrv, - X, - Y=None, - kind="pdp", - xs_interval="linear", - xs_values=None, - var_idx=None, - var_discrete=None, - func=None, - samples=50, - instances=10, - random_seed=None, - sharey=True, - rug=True, - smooth=True, - indices=None, - grid="long", + bartrv: Variable, + X: npt.NDArray[np.float_], + Y: Optional[npt.NDArray[np.float_]] = None, + kind: str = "pdp", + xs_interval: str = "linear", + xs_values: Optional[Union[int, List[float]]] = None, + var_idx: Optional[List[int]] = None, + var_discrete: Optional[List[int]] = None, + func: Optional[Callable] = None, + samples: int = 50, + instances: int = 10, + random_seed: Optional[int] = None, + sharey: bool = True, + rug: bool = True, + smooth: bool = True, + grid: str = "long", color="C0", + color_mean: str = "C0", + alpha: float = 0.1, + figsize: Optional[Tuple[float, float]] = None, + smooth_kwargs: Optional[Dict[str, Any]] = None, + ax: Optional[plt.Axes] = None, ): """ Partial dependence or individual conditional expectation plot.
@@ -142,9 +161,9 @@ def plot_dependence( ---------- bartrv : BART Random Variable BART variable once the model that includes it has been fitted. - X : array-like + X : npt.NDArray[np.float_] The covariate matrix. - Y : array-like + Y : Optional[npt.NDArray[np.float_]], by default None. The response vector. kind : str Whether to plot a partial dependence plot ("pdp") or an individual conditional expectation @@ -154,22 +173,22 @@ evenly spaced values in the range of X. "quantiles", the evaluation is done at the specified quantiles of X. "insample", the evaluation is done at the values of X. For discrete variables these options are omitted. - xs_values : int or list + xs_values : Optional[Union[int, List[float]]], by default None. Values of X used to evaluate the predicted function. If ``xs_interval="linear"`` number of points in the evenly spaced grid. If ``xs_interval="quantiles"`` quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive. Ignored when ``xs_interval="insample"``. - var_idx : list + var_idx : Optional[List[int]], by default None. List of the indices of the covariate for which to compute the pdp or ice. - var_discrete : list + var_discrete : Optional[List[int]], by default None. List of the indices of the covariate treated as discrete. - func : function + func : Optional[Callable], by default None. Arbitrary function to apply to the predictions. Defaults to the identity function. samples : int Number of posterior samples used in the predictions. Defaults to 50. instances : int Number of instances of X to plot. Only relevant for ``kind="ice"`` plots. - random_seed : int + random_seed : Optional[int], by default None. Seed used to sample from the posterior. Defaults to None. sharey : bool Controls sharing of properties among y-axes. Defaults to True.
@@ -219,7 +238,7 @@ def plot_dependence( else: x_names = [] - if hasattr(Y, "name"): + if Y is not None and hasattr(Y, "name"): y_label = f"Predicted {Y.name}" else: y_label = "Predicted Y" @@ -247,7 +266,7 @@ def plot_dependence( xs_values = [0.05, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.95] if kind == "ice": - instances = rng.choice(range(X.shape[0]), replace=False, size=instances) + instances_ary = rng.choice(range(X.shape[0]), replace=False, size=instances) new_y = [] new_x_target = [] @@ -264,9 +283,9 @@ def plot_dependence( if i in var_discrete: new_x_i = np.unique(X[:, i]) else: - if xs_interval == "linear": + if xs_interval == "linear" and isinstance(xs_values, int): new_x_i = np.linspace(np.nanmin(X[:, i]), np.nanmax(X[:, i]), xs_values) - elif xs_interval == "quantiles": + elif xs_interval == "quantiles" and isinstance(xs_values, list): new_x_i = np.quantile(X[:, i], q=xs_values) elif xs_interval == "insample": new_x_i = X[:, i] @@ -279,7 +298,7 @@ def plot_dependence( ) new_x_target.append(new_x_i) else: - for instance in instances: + for instance in instances_ary: new_X = X[idx_s] new_X[:, indices_mi] = X[:, indices_mi][instance] y_pred.append( @@ -292,9 +311,9 @@ def plot_dependence( if func is not None: new_y = [func(nyi) for nyi in new_y] shape = 1 - new_y = np.array(new_y) - if new_y[0].ndim == 3: - shape = new_y[0].shape[0] + new_y_ary = np.array(new_y) + if new_y_ary[0].ndim == 3: + shape = new_y_ary[0].shape[0] if ax is None: if grid == "long": fig, axes = plt.subplots(len(var_idx) * shape, sharey=sharey, figsize=figsize) @@ -303,7 +322,7 @@ def plot_dependence( elif isinstance(grid, tuple): fig, axes = plt.subplots(grid[0], grid[1], sharey=sharey, figsize=figsize) grid_size = grid[0] * grid[1] - n_plots = new_y.squeeze().shape[0] + n_plots = new_y_ary.squeeze().shape[0] if n_plots > grid_size: warnings.warn("The grid is smaller than the number of available variables to plot") elif n_plots < grid_size: @@ -318,7 +337,7 @@ def plot_dependence( x_idx = 0 y_idx = 0 for ax in axes: # pylint: disable=redefined-argument-from-local - nyi = new_y[x_idx][y_idx] + nyi = new_y_ary[x_idx][y_idx] nxi = new_x_target[x_idx] var = var_idx[x_idx] @@ -388,8 +407,15 @@ def plot_dependence( def plot_variable_importance( - idata, bartrv, X, labels=None, sort_vars=True, figsize=None, samples=100, random_seed=None -): + idata: az.InferenceData, + bartrv: Variable, + X: npt.NDArray[np.float_], + labels: Optional[List[str]] = None, + sort_vars: bool = True, + figsize: Optional[Tuple[float, float]] = None, + samples: int = 100, + random_seed: Optional[int] = None, +) -> Tuple[npt.NDArray[np.int_], List[plt.Axes]]: """ Estimates variable importance from the BART-posterior. @@ -399,9 +425,9 @@ def plot_variable_importance( InferenceData containing a collection of BART_trees in sample_stats group bartrv : BART Random Variable BART variable once the model that includes it has been fitted. - X : array-like + X : npt.NDArray[np.float_] The covariate matrix. - labels : list + labels : Optional[List[str]] List of the names of the covariates. If X is a DataFrame the names of the covariates will be taken from it and this argument will be ignored. sort_vars : bool @@ -410,7 +436,7 @@ Figure size. If None it will be defined automatically. samples : int Number of predictions used to compute correlation for subsets of variables. Defaults to 100. - random_seed : int + random_seed : Optional[int] Seed used to sample from the posterior. Defaults to None.
Returns ------- @@ -423,18 +449,18 @@ def plot_variable_importance( labels = X.columns X = X.values - var_imp = idata.sample_stats["variable_inclusion"].mean(("chain", "draw")).values + var_imp = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values if labels is None: - labels = np.arange(len(var_imp)) + labels_ary = np.arange(len(var_imp)) else: - labels = np.array(labels) + labels_ary = np.array(labels) rng = np.random.default_rng(random_seed) ticks = np.arange(len(var_imp), dtype=int) idxs = np.argsort(var_imp) subsets = [idxs[:-i] for i in range(1, len(idxs))] - subsets.append(None) + subsets.append(None) # type: ignore if sort_vars: indices = idxs[::-1] @@ -442,7 +468,7 @@ def plot_variable_importance( indices = np.arange(len(var_imp)) axes[0].plot((var_imp / var_imp.sum())[indices], "o-") axes[0].set_xticks(ticks) - axes[0].set_xticklabels(labels[indices]) + axes[0].set_xticklabels(labels_ary[indices]) axes[0].set_xlabel("covariables") axes[0].set_ylabel("importance") @@ -453,7 +479,9 @@ def plot_variable_importance( ev_mean = np.zeros(len(var_imp)) ev_hdi = np.zeros((len(var_imp), 2)) for idx, subset in enumerate(subsets): - predicted_subset = _sample_posterior(all_trees, X=X, rng=rng, size=samples, excluded=subset) + predicted_subset = _sample_posterior( + all_trees=all_trees, X=X, rng=rng, size=samples, excluded=subset + ) pearson = np.zeros(samples) for j in range(samples): pearson[j] = ( diff --git a/pyproject.toml b/pyproject.toml index aa4e5e2..8f1bbef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,10 @@ [tool.pytest.ini_options] minversion = "6.0" xfail_strict=true +addopts = [ + "-vv", + "--color=yes", +] [tool.black] line-length = 100 diff --git a/requirements-dev.txt b/requirements-dev.txt index 172603d..e51a93c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,8 @@ black==22.3.0 click==8.0.4 +mypy>=1.1.1 +pandas-stubs==1.5.3.230304 +pre-commit pylint==2.10.2 pytest-cov>=2.6.1 pytest>=4.4.0 diff --git a/setup.py b/setup.py index b38a7ba..d990b7f 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,6 @@ # limitations under the License. import re - from codecs import open from os.path import dirname, join, realpath diff --git a/tests/test_bart.py b/tests/test_bart.py index c7578d0..d8a3a27 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -3,12 +3,11 @@ import pytest from numpy.random import RandomState from numpy.testing import assert_almost_equal, assert_array_equal +from pymc.initial_point import make_initial_point_fn +from pymc.logprob.basic import joint_logp import pymc_bart as pmb -from pymc.logprob.joint_logprob import joint_logp -from pymc.initial_point import make_initial_point_fn - def assert_moment_is_expected(model, expected, check_finite_logp=True): fn = make_initial_point_fn( diff --git a/tests/test_pgbart.py b/tests/test_pgbart.py index 1227b7b..ec9fbac 100644 --- a/tests/test_pgbart.py +++ b/tests/test_pgbart.py @@ -1,8 +1,11 @@ from unittest import TestCase + import numpy as np import pymc as pm + import pymc_bart as pmb -from pymc_bart.pgbart import fast_mean, discrete_uniform_sampler, NormalSampler, UniformSampler +from pymc_bart.pgbart import (NormalSampler, UniformSampler, + discrete_uniform_sampler, fast_mean) class TestSystematic(TestCase):
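One detail worth verifying in tree.py's array-backed layout: the children of the node at index i live at 2*i + 1 and 2*i + 2, and the cached `get_depth` inverts that arithmetic. A self-contained sketch (`get_depth` copied from the diff; the loop values are illustrative):

    import math
    from functools import lru_cache


    @lru_cache
    def get_depth(index: int) -> int:
        # Depth of the node stored at `index` in breadth-first (array) order.
        return math.floor(math.log2(index + 1))


    for index in (0, 1, 2, 3, 6):
        left, right = index * 2 + 1, index * 2 + 2
        print(index, get_depth(index), (left, right))
    # 0 at depth 0 with children (1, 2); 3 at depth 2 with children (7, 8);
    # 6 at depth 2 with children (13, 14)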