From 7c1c5bb4bd75cf2ca1545fb99e2d05d98cb1e0f8 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Mon, 20 Mar 2023 22:14:13 +0100 Subject: [PATCH 01/26] mypy init --- .github/workflows/test.yml | 4 + .pre-commit-config.yaml | 6 ++ mypy.ini | 15 ++++ pymc_bart/__init__.py | 3 +- pymc_bart/bart.py | 5 +- pymc_bart/pgbart.py | 154 ++++++++++++++++++------------------- pymc_bart/tree.py | 147 +++++++++++++++++++++-------------- pymc_bart/utils.py | 3 +- requirements-dev.txt | 3 + setup.py | 1 - tests/test_bart.py | 5 +- tests/test_pgbart.py | 5 +- 12 files changed, 204 insertions(+), 147 deletions(-) create mode 100644 mypy.ini diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3b5a5df..db9f630 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -43,6 +43,10 @@ jobs: echo "Success!" echo "Checking code style with pylint..." python -m pylint pymc_bart/ + - name: Run Mypy + shell: bash -l {0} + run: | + python -m mypy pymc_bart - name: Run tests shell: bash -l {0} run: | diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index da409e7..5432a77 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,6 +13,12 @@ repos: language: system types: [python] files: ^pymc_bart/ + - id: mypy + name: mypy + entry: mypy + language: system + types: [python] + files: ^pymc_bart/ - repo: https://github.com/pre-commit/pre-commit-hooks rev: v3.2.0 hooks: diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..56088d7 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,15 @@ +[mypy] +files = pymc_bart/*.py +plugins = numpy.typing.mypy_plugin + +[mypy-matplotlib.*] +ignore_missing_imports = True + +[mypy-numba.*] +ignore_missing_imports = True + +[mypy-pymc.*] +ignore_missing_imports = True + +[mypy-scipy.*] +ignore_missing_imports = True diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index da62697..1e20653 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -15,7 +15,8 @@ from pymc_bart.bart import BART from pymc_bart.pgbart import PGBART -from pymc_bart.utils import plot_convergence, plot_dependence, plot_variable_importance +from pymc_bart.utils import (plot_convergence, plot_dependence, + plot_variable_importance) __all__ = ["BART", "PGBART"] __version__ = "0.4.0" diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index dad53d6..ac0bcd8 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -15,15 +15,14 @@ # limitations under the License. from multiprocessing import Manager + import numpy as np +import pytensor.tensor as pt from pandas import DataFrame, Series - from pymc.distributions.distribution import Distribution, _moment from pymc.logprob.abstract import _logprob -import pytensor.tensor as pt from pytensor.tensor.random.op import RandomVariable - from .utils import _sample_posterior __all__ = ["BART"] diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 6755836..f4473ed 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -12,21 +12,76 @@ # See the License for the specific language governing permissions and # limitations under the License. -from numba import njit +from typing import List, Optional, Union import numpy as np - -from pytensor import function as pytensor_function -from pytensor import config -from pytensor.tensor.var import Variable - -from pymc.model import modelcontext +from numba import njit +from pymc.model import Model, modelcontext +from pymc.pytensorf import inputvars, join_nonshared_inputs, make_shared_replacements from pymc.step_methods.arraystep import ArrayStepShared from pymc.step_methods.compound import Competence -from pymc.pytensorf import inputvars, join_nonshared_inputs, make_shared_replacements +from pytensor import config +from pytensor import function as pytensor_function +from pytensor.tensor.var import Variable from pymc_bart.bart import BARTRV -from pymc_bart.tree import Tree, Node, get_depth +from pymc_bart.tree import Node, Tree, get_depth + + +class ParticleTree: + """Particle tree.""" + + __slots__ = "tree", "expansion_nodes", "log_weight", "kfactor" + + def __init__(self, tree: Tree, kfactor: float = 0.75): + self.tree: Tree = tree.copy() + self.expansion_nodes: List[int] = [0] + self.log_weight: float = 0 + self.kfactor: float = kfactor + + def copy(self) -> "ParticleTree": + p = ParticleTree(self.tree) + p.expansion_nodes = self.expansion_nodes.copy() + p.kfactor = self.kfactor + return p + + def sample_tree( + self, + ssv, + available_predictors, + prior_prob_leaf_node, + X, + missing_data, + sum_trees, + m, + normal, + shape, + ) -> bool: + tree_grew = False + if self.expansion_nodes: + index_leaf_node = self.expansion_nodes.pop(0) + # Probability that this node will remain a leaf node + prob_leaf = prior_prob_leaf_node[get_depth(index_leaf_node)] + + if prob_leaf < np.random.random(): + idx_new_nodes = grow_tree( + self.tree, + index_leaf_node, + ssv, + available_predictors, + X, + missing_data, + sum_trees, + m, + normal, + self.kfactor, + shape, + ) + if idx_new_nodes is not None: + self.expansion_nodes.extend(idx_new_nodes) + tree_grew = True + + return tree_grew class PGBART(ArrayStepShared): @@ -55,9 +110,9 @@ class PGBART(ArrayStepShared): def __init__( self, vars=None, # pylint: disable=redefined-builtin - num_particles=20, - batch="auto", - model=None, + num_particles: int = 20, + batch: Union[str, int] = "auto", + model: Optional[Model] = None, ): model = modelcontext(model) initial_values = model.initial_point() @@ -205,7 +260,7 @@ def astep(self, _): stats = {"variable_inclusion": variable_inclusion, "tune": self.tune} return self.sum_trees, [stats] - def normalize(self, particles): + def normalize(self, particles) -> float: """ Use softmax to get normalized_weights. """ @@ -215,7 +270,7 @@ def normalize(self, particles): wei = np.exp(log_w_) + 1e-12 return wei / wei.sum() - def resample(self, particles, normalized_weights): + def resample(self, particles: List[ParticleTree], normalized_weights) -> List[ParticleTree]: """ Use systematic resample for all but the first particle @@ -258,18 +313,17 @@ def systematic(self, normalized_weights): single_uniform = (self.uniform.rvs() + np.arange(lnw)) / lnw return inverse_cdf(single_uniform, normalized_weights) - def init_particles(self, tree_id: int) -> np.ndarray: + def init_particles(self, tree_id: int) -> List[ParticleTree]: """Initialize particles.""" - p0 = self.all_particles[tree_id] + p0: ParticleTree = self.all_particles[tree_id] # The old tree does not grow so we update the weight only once self.update_weight(p0) - particles = [p0] - - for _ in self.indices: - particles.append( - ParticleTree(self.a_tree, self.uniform_kf.rvs() if self.tune else p0.kfactor) - ) + particles: List[ParticleTree] = [p0] + particles.extend( + ParticleTree(self.a_tree, self.uniform_kf.rvs() if self.tune else p0.kfactor) + for _ in self.indices + ) return particles def update_weight(self, particle): @@ -290,62 +344,6 @@ def competence(var, has_grad): return Competence.INCOMPATIBLE -class ParticleTree: - """Particle tree.""" - - __slots__ = "tree", "expansion_nodes", "log_weight", "kfactor" - - def __init__(self, tree, kfactor=0.75): - self.tree = tree.copy() - self.expansion_nodes = [0] - self.log_weight = 0 - self.kfactor = kfactor - - def copy(self): - p = ParticleTree(self.tree) - p.expansion_nodes = self.expansion_nodes.copy() - p.kfactor = self.kfactor - return p - - def sample_tree( - self, - ssv, - available_predictors, - prior_prob_leaf_node, - X, - missing_data, - sum_trees, - m, - normal, - shape, - ): - tree_grew = False - if self.expansion_nodes: - index_leaf_node = self.expansion_nodes.pop(0) - # Probability that this node will remain a leaf node - prob_leaf = prior_prob_leaf_node[get_depth(index_leaf_node)] - - if prob_leaf < np.random.random(): - idx_new_nodes = grow_tree( - self.tree, - index_leaf_node, - ssv, - available_predictors, - X, - missing_data, - sum_trees, - m, - normal, - self.kfactor, - shape, - ) - if idx_new_nodes is not None: - self.expansion_nodes.extend(idx_new_nodes) - tree_grew = True - - return tree_grew - - class SampleSplittingVariable: def __init__(self, alpha_vec): """ diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 4f98709..7e494f6 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -13,11 +13,55 @@ # limitations under the License. import math - from functools import lru_cache +from typing import Dict, Generator, List, Optional -from pytensor import config import numpy as np +import numpy.typing as npt +from pytensor import config + + +class Node: + __slots__ = "index", "value", "idx_split_variable", "idx_data_points" + + def __init__( + self, + index: int, + value: float = -1.0, + idx_data_points: Optional[List[int]] = None, + idx_split_variable: int = -1, + ) -> None: + self.index = index + self.value = value + self.idx_data_points = idx_data_points + self.idx_split_variable = idx_split_variable + + @classmethod + def new_leaf_node( + cls, index: int, value: float, idx_data_points: Optional[List[int]] + ) -> "Node": + return cls(index, value=value, idx_data_points=idx_data_points) + + @classmethod + def new_split_node(cls, index: int, split_value: float, idx_split_variable: int) -> "Node": + return cls(index=index, value=split_value, idx_split_variable=idx_split_variable) + + def get_idx_left_child(self) -> int: + return self.index * 2 + 1 + + def get_idx_right_child(self) -> int: + return self.index * 2 + 2 + + def is_split_node(self) -> bool: + return self.idx_split_variable >= 0 + + def is_leaf_node(self) -> bool: + return not self.is_split_node() + + +@lru_cache +def get_depth(index: int) -> int: + return math.floor(math.log2(index + 1)) class Tree: @@ -52,66 +96,88 @@ class Tree: "output", ) - def __init__(self, tree_structure, idx_leaf_nodes, output): + def __init__( + self, + tree_structure: Dict[int, Node], + idx_leaf_nodes: Optional[List[int]], + output: Optional[npt.NDArray], + ) -> None: self.tree_structure = tree_structure self.idx_leaf_nodes = idx_leaf_nodes self.output = output @classmethod - def new_tree(cls, leaf_node_value, idx_data_points, num_observations, shape): + def new_tree( + cls, + leaf_node_value: float, + idx_data_points: Optional[List[int]], + num_observations: int, + shape: int, + ) -> "Tree": return cls( tree_structure={ - 0: Node.new_leaf_node(0, value=leaf_node_value, idx_data_points=idx_data_points) + 0: Node.new_leaf_node( + index=0, value=leaf_node_value, idx_data_points=idx_data_points + ) }, idx_leaf_nodes=[0], output=np.zeros((num_observations, shape)).astype(config.floatX).squeeze(), ) - def __getitem__(self, index): + def __getitem__(self, index) -> Node: return self.get_node(index) - def __setitem__(self, index, node): + def __setitem__(self, index, node) -> None: self.set_node(index, node) - def copy(self): - tree = { + def copy(self) -> "Tree": + tree: Dict[int, Node] = { k: Node(v.index, v.value, v.idx_data_points, v.idx_split_variable) for k, v in self.tree_structure.items() } - return Tree(tree, self.idx_leaf_nodes.copy(), self.output.copy()) + idx_leaf_nodes = self.idx_leaf_nodes.copy() if self.idx_leaf_nodes else None + output = self.output.copy() if self.output is not None else None + return Tree(tree_structure=tree, idx_leaf_nodes=idx_leaf_nodes, output=output) - def get_node(self, index) -> "Node": + def get_node(self, index: int) -> Node: return self.tree_structure[index] - def set_node(self, index, node): + def set_node(self, index: int, node: Node) -> None: self.tree_structure[index] = node - if node.is_leaf_node(): + if node.is_leaf_node() and self.idx_leaf_nodes is not None: self.idx_leaf_nodes.append(index) - def grow_leaf_node(self, current_node, selected_predictor, split_value, index_leaf_node): + def grow_leaf_node(self, current_node: Node, selected_predictor, split_value, index_leaf_node): current_node.value = split_value current_node.idx_split_variable = selected_predictor current_node.idx_data_points = None - self.idx_leaf_nodes.remove(index_leaf_node) + if self.idx_leaf_nodes is not None: + self.idx_leaf_nodes.remove(index_leaf_node) - def trim(self): - tree = { + def trim(self) -> "Tree": + tree: Dict[int, Node] = { k: Node(v.index, v.value, None, v.idx_split_variable) for k, v in self.tree_structure.items() } return Tree(tree, None, None) - def get_split_variables(self): + def get_split_variables(self) -> Generator[int, None, None]: for node in self.tree_structure.values(): if node.is_split_node(): yield node.idx_split_variable - def _predict(self): + def _predict(self) -> Optional[npt.NDArray]: output = self.output - for node_index in self.idx_leaf_nodes: - leaf_node = self.get_node(node_index) - output[leaf_node.idx_data_points] = leaf_node.value - return output.T + + if output is None: + return None # ? What do we return here? + + else: + if self.idx_leaf_nodes is not None: + for node_index in self.idx_leaf_nodes: + leaf_node = self.get_node(node_index) + output[leaf_node.idx_data_points] = leaf_node.value + return output.T if output is not None else None def predict(self, x, excluded=None): """ @@ -178,38 +244,3 @@ def _traverse_leaf_values(self, leaf_values, node_index): else: self._traverse_leaf_values(leaf_values, node.get_idx_left_child()) self._traverse_leaf_values(leaf_values, node.get_idx_right_child()) - - -class Node: - __slots__ = "index", "value", "idx_split_variable", "idx_data_points" - - def __init__(self, index: int, value=-1, idx_data_points=None, idx_split_variable=-1): - self.index = index - self.value = value - self.idx_data_points = idx_data_points - self.idx_split_variable = idx_split_variable - - @classmethod - def new_leaf_node(cls, index: int, value, idx_data_points) -> "Node": - return cls(index, value=value, idx_data_points=idx_data_points) - - @classmethod - def new_split_node(cls, index: int, split_value, idx_split_variable) -> "Node": - return cls(index, value=split_value, idx_split_variable=idx_split_variable) - - def get_idx_left_child(self) -> int: - return self.index * 2 + 1 - - def get_idx_right_child(self) -> int: - return self.index * 2 + 2 - - def is_split_node(self) -> bool: - return self.idx_split_variable >= 0 - - def is_leaf_node(self) -> bool: - return not self.is_split_node() - - -@lru_cache -def get_depth(index: int) -> int: - return math.floor(math.log2(index + 1)) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index a693955..b1a22bc 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -5,11 +5,10 @@ import arviz as az import matplotlib.pyplot as plt import numpy as np - from pytensor.tensor.var import Variable from scipy.interpolate import griddata from scipy.signal import savgol_filter -from scipy.stats import pearsonr, norm +from scipy.stats import norm, pearsonr def _sample_posterior(all_trees, X, rng, size=None, excluded=None): diff --git a/requirements-dev.txt b/requirements-dev.txt index 172603d..cc34345 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,8 @@ black==22.3.0 click==8.0.4 +mypy==1.1.1 +pandas-stubs==1.5.3.230304 +pre-commit pylint==2.10.2 pytest-cov>=2.6.1 pytest>=4.4.0 diff --git a/setup.py b/setup.py index b38a7ba..d990b7f 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,6 @@ # limitations under the License. import re - from codecs import open from os.path import dirname, join, realpath diff --git a/tests/test_bart.py b/tests/test_bart.py index c7578d0..6e47b99 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -3,12 +3,11 @@ import pytest from numpy.random import RandomState from numpy.testing import assert_almost_equal, assert_array_equal +from pymc.initial_point import make_initial_point_fn +from pymc.logprob.joint_logprob import joint_logp import pymc_bart as pmb -from pymc.logprob.joint_logprob import joint_logp -from pymc.initial_point import make_initial_point_fn - def assert_moment_is_expected(model, expected, check_finite_logp=True): fn = make_initial_point_fn( diff --git a/tests/test_pgbart.py b/tests/test_pgbart.py index 1227b7b..ec9fbac 100644 --- a/tests/test_pgbart.py +++ b/tests/test_pgbart.py @@ -1,8 +1,11 @@ from unittest import TestCase + import numpy as np import pymc as pm + import pymc_bart as pmb -from pymc_bart.pgbart import fast_mean, discrete_uniform_sampler, NormalSampler, UniformSampler +from pymc_bart.pgbart import (NormalSampler, UniformSampler, + discrete_uniform_sampler, fast_mean) class TestSystematic(TestCase): From 6822cead1f05cc5e5f46b77f075aaba04de3ce24 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Mon, 20 Mar 2023 22:29:35 +0100 Subject: [PATCH 02/26] undo change --- pymc_bart/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index 1e20653..da62697 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -15,8 +15,7 @@ from pymc_bart.bart import BART from pymc_bart.pgbart import PGBART -from pymc_bart.utils import (plot_convergence, plot_dependence, - plot_variable_importance) +from pymc_bart.utils import plot_convergence, plot_dependence, plot_variable_importance __all__ = ["BART", "PGBART"] __version__ = "0.4.0" From d1313446eb6b9d9e592ff41611032994fb45dfd3 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Tue, 21 Mar 2023 21:37:58 +0100 Subject: [PATCH 03/26] fix tree.py types --- pymc_bart/pgbart.py | 4 ++-- pymc_bart/tree.py | 41 ++++++++++++++++++++++++----------------- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index f4473ed..582b732 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -420,7 +420,7 @@ def grow_tree( current_node.get_idx_right_child(), ) - new_nodes = [] + new_nodes = np.array([]) for idx in range(2): idx_data_point = new_idx_data_points[idx] node_value = draw_leaf_value( @@ -435,7 +435,7 @@ def grow_tree( value=node_value, idx_data_points=idx_data_point, ) - new_nodes.append(new_node) + new_nodes = np.append(new_nodes, new_node) tree.grow_leaf_node(current_node, selected_predictor, split_value, index_leaf_node) tree.set_node(new_nodes[0].index, new_nodes[0]) diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 7e494f6..359a01b 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -28,7 +28,7 @@ def __init__( self, index: int, value: float = -1.0, - idx_data_points: Optional[List[int]] = None, + idx_data_points: Optional[npt.NDArray[np.int_]] = None, idx_split_variable: int = -1, ) -> None: self.index = index @@ -38,7 +38,7 @@ def __init__( @classmethod def new_leaf_node( - cls, index: int, value: float, idx_data_points: Optional[List[int]] + cls, index: int, value: float, idx_data_points: Optional[npt.NDArray[np.int_]] ) -> "Node": return cls(index, value=value, idx_data_points=idx_data_points) @@ -99,8 +99,8 @@ class Tree: def __init__( self, tree_structure: Dict[int, Node], - idx_leaf_nodes: Optional[List[int]], - output: Optional[npt.NDArray], + idx_leaf_nodes: Optional[npt.NDArray[np.int_]], + output: Optional[npt.NDArray[np.float_]], ) -> None: self.tree_structure = tree_structure self.idx_leaf_nodes = idx_leaf_nodes @@ -110,7 +110,7 @@ def __init__( def new_tree( cls, leaf_node_value: float, - idx_data_points: Optional[List[int]], + idx_data_points: Optional[npt.NDArray[np.int_]], num_observations: int, shape: int, ) -> "Tree": @@ -120,7 +120,7 @@ def new_tree( index=0, value=leaf_node_value, idx_data_points=idx_data_points ) }, - idx_leaf_nodes=[0], + idx_leaf_nodes=np.array([0]), output=np.zeros((num_observations, shape)).astype(config.floatX).squeeze(), ) @@ -135,7 +135,7 @@ def copy(self) -> "Tree": k: Node(v.index, v.value, v.idx_data_points, v.idx_split_variable) for k, v in self.tree_structure.items() } - idx_leaf_nodes = self.idx_leaf_nodes.copy() if self.idx_leaf_nodes else None + idx_leaf_nodes = self.idx_leaf_nodes.copy() if self.idx_leaf_nodes is not None else None output = self.output.copy() if self.output is not None else None return Tree(tree_structure=tree, idx_leaf_nodes=idx_leaf_nodes, output=output) @@ -145,14 +145,18 @@ def get_node(self, index: int) -> Node: def set_node(self, index: int, node: Node) -> None: self.tree_structure[index] = node if node.is_leaf_node() and self.idx_leaf_nodes is not None: - self.idx_leaf_nodes.append(index) + # self.idx_leaf_nodes.append(index) + self.idx_leaf_nodes = np.append(self.idx_leaf_nodes, index) - def grow_leaf_node(self, current_node: Node, selected_predictor, split_value, index_leaf_node): + def grow_leaf_node( + self, current_node: Node, selected_predictor: int, split_value: float, index_leaf_node: int + ) -> None: current_node.value = split_value current_node.idx_split_variable = selected_predictor current_node.idx_data_points = None if self.idx_leaf_nodes is not None: - self.idx_leaf_nodes.remove(index_leaf_node) + # self.idx_leaf_nodes.remove(index_leaf_node) + self.idx_leaf_nodes = np.setdiff1d(self.idx_leaf_nodes, index_leaf_node) def trim(self) -> "Tree": tree: Dict[int, Node] = { @@ -166,7 +170,7 @@ def get_split_variables(self) -> Generator[int, None, None]: if node.is_split_node(): yield node.idx_split_variable - def _predict(self) -> Optional[npt.NDArray]: + def _predict(self) -> Optional[npt.NDArray[np.float_]]: output = self.output if output is None: @@ -179,7 +183,7 @@ def _predict(self) -> Optional[npt.NDArray]: output[leaf_node.idx_data_points] = leaf_node.value return output.T if output is not None else None - def predict(self, x, excluded=None): + def predict(self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = None) -> float: """ Predict output of tree for an (un)observed point x. @@ -199,7 +203,7 @@ def predict(self, x, excluded=None): excluded = [] return self._traverse_tree(x, 0, excluded) - def _traverse_tree(self, x, node_index, excluded): + def _traverse_tree(self, x, node_index: int, excluded: Optional[List[int]] = None) -> float: """ Traverse the tree starting from a particular node given an unobserved point. @@ -207,16 +211,18 @@ def _traverse_tree(self, x, node_index, excluded): ---------- x : np.ndarray node_index : int + excluded: list Returns ------- Leaf node value or mean of leaf node values """ - current_node = self.get_node(node_index) + current_node: Node = self.get_node(node_index) if current_node.is_leaf_node(): return current_node.value - if current_node.idx_split_variable in excluded: - leaf_values = [] + + if excluded is not None and current_node.idx_split_variable in excluded: + leaf_values: List[float] = [] self._traverse_leaf_values(leaf_values, node_index) return np.mean(leaf_values, 0) @@ -226,12 +232,13 @@ def _traverse_tree(self, x, node_index, excluded): next_node = current_node.get_idx_right_child() return self._traverse_tree(x, next_node, excluded) - def _traverse_leaf_values(self, leaf_values, node_index): + def _traverse_leaf_values(self, leaf_values: List[float], node_index: int) -> None: """ Traverse the tree appending leaf values starting from a particular node. Parameters ---------- + leaf_values : list node_index : int Returns From d889df6729fc25be6a98f48f355814f81fc4dcf0 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Thu, 23 Mar 2023 18:16:52 +0100 Subject: [PATCH 04/26] small code improvements --- pymc_bart/pgbart.py | 9 ++++--- pymc_bart/tree.py | 58 ++++++++++++++++++++++++++------------------- pyproject.toml | 4 ++++ 3 files changed, 42 insertions(+), 29 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 582b732..f3383d7 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Union +from typing import List, Optional, Union, Tuple import numpy as np from numba import njit @@ -184,7 +184,7 @@ def __init__( self.indices = list(range(1, num_particles)) shared = make_shared_replacements(initial_values, vars, model) self.likelihood_logp = logp(initial_values, [model.datalogp], vars, shared) - self.all_particles = list(ParticleTree(self.a_tree) for _ in range(self.m)) + self.all_particles = [ParticleTree(self.a_tree) for _ in range(self.m)] self.all_trees = np.array([p.tree for p in self.all_particles]) self.lower = 0 self.iter = 0 @@ -209,7 +209,7 @@ def astep(self, _): # Sample each particle (try to grow each tree), except for the first one stop_growing = True for p in particles[1:]: - tree_grew = p.sample_tree( + if p.sample_tree( self.ssv, self.available_predictors, self.prior_prob_leaf_node, @@ -219,8 +219,7 @@ def astep(self, _): self.m, self.normal, self.shape, - ) - if tree_grew: + ): self.update_weight(p) if p.expansion_nodes: stop_growing = False diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 359a01b..ab6233c 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -22,6 +22,16 @@ class Node: + """Node of a binary tree. + + Attributes + ---------- + index : int + value : float + idx_data_points : Optional[npt.NDArray[np.int_]] + idx_split_variable : Optional[npt.NDArray[np.int_]] + """ + __slots__ = "index", "value", "idx_split_variable", "idx_data_points" def __init__( @@ -72,21 +82,21 @@ class Tree: Attributes ---------- - tree_structure : dict + tree_structure : Dict[int, Node] A dictionary that represents the nodes stored in breadth-first order, based in the array method for storing binary trees (https://en.wikipedia.org/wiki/Binary_tree#Arrays). The dictionary's keys are integers that represent the nodes position. The dictionary's values are objects of type Node that represent the split and leaf nodes of the tree itself. - idx_leaf_nodes : list - List with the index of the leaf nodes of the tree. - output: array + idx_leaf_nodes : Optional[npt.NDArray[np.int_]] + Array with the index of the leaf nodes of the tree. + output: Optional[npt.NDArray[np.float_]] Array of shape number of observations, shape Parameters ---------- tree_structure : Dictionary of nodes - idx_leaf_nodes : List with the index of the leaf nodes of the tree. + idx_leaf_nodes : Array with the index of the leaf nodes of the tree. output : Array of shape number of observations, shape """ @@ -174,14 +184,13 @@ def _predict(self) -> Optional[npt.NDArray[np.float_]]: output = self.output if output is None: - return None # ? What do we return here? + return None - else: - if self.idx_leaf_nodes is not None: - for node_index in self.idx_leaf_nodes: - leaf_node = self.get_node(node_index) - output[leaf_node.idx_data_points] = leaf_node.value - return output.T if output is not None else None + if self.idx_leaf_nodes is not None: + for node_index in self.idx_leaf_nodes: + leaf_node = self.get_node(node_index) + output[leaf_node.idx_data_points] = leaf_node.value + return output.T if output is not None else None def predict(self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = None) -> float: """ @@ -189,10 +198,10 @@ def predict(self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = Non Parameters ---------- - x : numpy array + x : npt.NDArray[np.float_] Unobserved point - excluded: list - Indexes of the variables to exclude when computing predictions + excluded: Optional[List[int]] + Indexes of the variables to exclude when computing predictions Returns ------- @@ -203,15 +212,20 @@ def predict(self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = Non excluded = [] return self._traverse_tree(x, 0, excluded) - def _traverse_tree(self, x, node_index: int, excluded: Optional[List[int]] = None) -> float: + def _traverse_tree( + self, x: npt.NDArray[np.float_], node_index: int, excluded: Optional[List[int]] = None + ) -> float: """ Traverse the tree starting from a particular node given an unobserved point. Parameters ---------- - x : np.ndarray + x : npt.NDArray[np.float_] + Unobserved point node_index : int - excluded: list + Index of the node to start the traversal from + excluded: Optional[List[int]] + Indexes of the variables to exclude when computing predictions Returns ------- @@ -224,7 +238,7 @@ def _traverse_tree(self, x, node_index: int, excluded: Optional[List[int]] = Non if excluded is not None and current_node.idx_split_variable in excluded: leaf_values: List[float] = [] self._traverse_leaf_values(leaf_values, node_index) - return np.mean(leaf_values, 0) + return np.mean(leaf_values, axis=0) if x[current_node.idx_split_variable] <= current_node.value: next_node = current_node.get_idx_left_child() @@ -238,12 +252,8 @@ def _traverse_leaf_values(self, leaf_values: List[float], node_index: int) -> No Parameters ---------- - leaf_values : list + leaf_values : List[float] node_index : int - - Returns - ------- - List of leaf node values """ node = self.get_node(node_index) if node.is_leaf_node(): diff --git a/pyproject.toml b/pyproject.toml index aa4e5e2..8f1bbef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,10 @@ [tool.pytest.ini_options] minversion = "6.0" xfail_strict=true +addopts = [ + "-vv", + "--color=yes", +] [tool.black] line-length = 100 From b12feb528096f7fe097cbafe45f78669c4caddee Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Thu, 23 Mar 2023 18:26:14 +0100 Subject: [PATCH 05/26] fix batch type --- pymc_bart/pgbart.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index f3383d7..f6191a1 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Union, Tuple +from typing import List, Optional, Tuple import numpy as np from numba import njit @@ -111,7 +111,7 @@ def __init__( self, vars=None, # pylint: disable=redefined-builtin num_particles: int = 20, - batch: Union[str, int] = "auto", + batch: Tuple[float, float] = (0.1, 0.1), model: Optional[Model] = None, ): model = modelcontext(model) @@ -171,14 +171,9 @@ def __init__( self.tune = True - if batch == "auto": - batch = max(1, int(self.m * 0.1)) - self.batch = (batch, batch) - else: - if isinstance(batch, (tuple, list)): - self.batch = batch - else: - self.batch = (batch, batch) + if batch == (0.1, 0.1): + batch_component = max(1, int(self.m * 0.1)) + self.batch = (batch_component, batch_component) self.num_particles = num_particles self.indices = list(range(1, num_particles)) From 6bb483da5fc48b5192a9bbca4c89afd83e185fa8 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Fri, 24 Mar 2023 16:13:40 +0100 Subject: [PATCH 06/26] add some more types --- pymc_bart/pgbart.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index f6191a1..71cb631 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import List, Optional, Tuple - +import numpy.typing as npt import numpy as np from numba import njit from pymc.model import Model, modelcontext @@ -132,10 +132,7 @@ def __init__( self.missing_data = np.any(np.isnan(self.X)) self.m = self.bart.m shape = initial_values[value_bart.name].shape - if len(shape) == 1: - self.shape = 1 - else: - self.shape = shape[0] + self.shape = 1 if len(shape) == 1 else shape[0] if self.bart.split_prior: self.alpha_vec = self.bart.split_prior @@ -254,7 +251,7 @@ def astep(self, _): stats = {"variable_inclusion": variable_inclusion, "tune": self.tune} return self.sum_trees, [stats] - def normalize(self, particles) -> float: + def normalize(self, particles: List[ParticleTree]) -> float: """ Use softmax to get normalized_weights. """ @@ -264,15 +261,17 @@ def normalize(self, particles) -> float: wei = np.exp(log_w_) + 1e-12 return wei / wei.sum() - def resample(self, particles: List[ParticleTree], normalized_weights) -> List[ParticleTree]: + def resample( + self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float_] + ) -> List[ParticleTree]: """ Use systematic resample for all but the first particle Ensure particles are copied only if needed. """ new_indices = self.systematic(normalized_weights) + 1 - seen = [] - new_particles = [] + seen: List[int] = [] + new_particles: List[ParticleTree] = [] for idx in new_indices: if idx in seen: new_particles.append(particles[idx].copy()) @@ -284,7 +283,9 @@ def resample(self, particles: List[ParticleTree], normalized_weights) -> List[Pa return particles - def get_particle_tree(self, particles, normalized_weights): + def get_particle_tree( + self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float_] + ) -> Tuple[ParticleTree, Tree]: """ Sample a new particle and associated tree """ @@ -295,7 +296,7 @@ def get_particle_tree(self, particles, normalized_weights): return new_particle, new_particle.tree - def systematic(self, normalized_weights): + def systematic(self, normalized_weights: npt.NDArray[np.float_]) -> npt.NDArray[np.int_]: """ Systematic resampling. @@ -320,7 +321,7 @@ def init_particles(self, tree_id: int) -> List[ParticleTree]: ) return particles - def update_weight(self, particle): + def update_weight(self, particle: ParticleTree) -> None: """ Update the weight of a particle. """ @@ -552,17 +553,19 @@ def update(self): @njit -def inverse_cdf(single_uniform, normalized_weights): +def inverse_cdf( + single_uniform: npt.NDArray[np.float_], normalized_weights: npt.NDArray[np.float_] +) -> npt.NDArray[np.int_]: """ Inverse CDF algorithm for a finite distribution. Parameters ---------- - single_uniform: ndarray - ordered points in [0,1] + single_uniform: npt.NDArray[np.float_] + Ordered points in [0,1] - normalized_weights: ndarray - normalized weights + normalized_weights: npt.NDArray[np.float_]) + Normalized weights Returns ------- From 8697f4601983b04d919745a1a1390b47a14c3cff Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Fri, 24 Mar 2023 21:11:24 +0100 Subject: [PATCH 07/26] Update pymc_bart/pgbart.py Co-authored-by: Osvaldo A Martin --- pymc_bart/pgbart.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 71cb631..4410d44 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -168,9 +168,9 @@ def __init__( self.tune = True - if batch == (0.1, 0.1): - batch_component = max(1, int(self.m * 0.1)) - self.batch = (batch_component, batch_component) + batch_0 = max(1, int(self.m * batch[0])) + batch_1 = max(1, int(self.m * batch[1])) + self.batch = ( batch_0, batch_1) self.num_particles = num_particles self.indices = list(range(1, num_particles)) From f94d73d69424d155080bea7e171f71b288ab09a0 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Fri, 24 Mar 2023 21:51:15 +0100 Subject: [PATCH 08/26] fix ident --- pymc_bart/pgbart.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 4410d44..6e049c7 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -168,9 +168,9 @@ def __init__( self.tune = True - batch_0 = max(1, int(self.m * batch[0])) - batch_1 = max(1, int(self.m * batch[1])) - self.batch = ( batch_0, batch_1) + batch_0 = max(1, int(self.m * batch[0])) + batch_1 = max(1, int(self.m * batch[1])) + self.batch = (batch_0, batch_1) self.num_particles = num_particles self.indices = list(range(1, num_particles)) From 38436a98739a40c5bb6dda7ab041dd6f89f4bbca Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Fri, 24 Mar 2023 22:14:29 +0100 Subject: [PATCH 09/26] more hints --- pymc_bart/pgbart.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 6e049c7..a993e19 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy.typing as npt import numpy as np from numba import njit @@ -340,7 +340,7 @@ def competence(var, has_grad): class SampleSplittingVariable: - def __init__(self, alpha_vec): + def __init__(self, alpha_vec: npt.NDArray[np.float_]) -> None: """ Sample splitting variables proportional to `alpha_vec`. @@ -349,15 +349,15 @@ def __init__(self, alpha_vec): """ self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum()))) - def rvs(self): - rnd = np.random.random() + def rvs(self) -> Union[int, Tuple[int, float]]: + rnd: float = np.random.random() for i, val in self.enu: if rnd <= val: return i return self.enu[-1] -def compute_prior_probability(alpha): +def compute_prior_probability(alpha) -> List[float]: """ Calculate the probability of the node being a leaf node (1 - p(being split node)). @@ -376,7 +376,7 @@ def compute_prior_probability(alpha): .. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART. arXiv, `link `__ """ - prior_leaf_prob = [0] + prior_leaf_prob: List[float] = [0] depth = 1 while prior_leaf_prob[-1] < 1: prior_leaf_prob.append(1 - alpha**depth) From 846b0b8fdc0843832ef45eaf800957e466442364 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Tue, 28 Mar 2023 21:35:05 +0200 Subject: [PATCH 10/26] add types utils and bart --- pymc_bart/bart.py | 37 ++++++++++++++++++++++--------------- pymc_bart/utils.py | 15 ++++++++++++++- 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index ac0bcd8..5b18921 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -17,13 +17,18 @@ from multiprocessing import Manager import numpy as np +import numpy.typing as npt import pytensor.tensor as pt from pandas import DataFrame, Series from pymc.distributions.distribution import Distribution, _moment from pymc.logprob.abstract import _logprob from pytensor.tensor.random.op import RandomVariable +from typing import List, Optional, Tuple + + +from .utils import _sample_posterior, TensorLike + -from .utils import _sample_posterior __all__ = ["BART"] @@ -31,11 +36,11 @@ class BARTRV(RandomVariable): """Base class for BART.""" - name = "BART" + name: str = "BART" ndim_supp = 1 - ndims_params = [2, 1, 0, 0, 1] - dtype = "floatX" - _print_name = ("BART", "\\operatorname{BART}") + ndims_params: List[int] = [2, 1, 0, 0, 1] + dtype: str = "floatX" + _print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}") all_trees = None def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): @@ -63,16 +68,16 @@ class BART(Distribution): Parameters ---------- - X : array-like + X : TensorLike The covariate matrix. - Y : array-like + Y : TensorLike The response vector. m : int Number of trees alpha : float Control the prior probability over the depth of the trees. Even when it can takes values in the interval (0, 1), it is recommended to be in the interval (0, 0.5]. - split_prior : array-like + split_prior : Optional[List[float]], default None. Each element of split_prior should be in the [0, 1] interval and the elements should sum to 1. Otherwise they will be normalized. Defaults to 0, i.e. all covariates have the same prior probability to be selected. @@ -80,12 +85,12 @@ class BART(Distribution): def __new__( cls, - name, - X, - Y, - m=50, - alpha=0.25, - split_prior=None, + name: str, + X: TensorLike, + Y: TensorLike, + m: int = 50, + alpha: float = 0.25, + split_prior: Optional[List[float]] = None, **kwargs, ): manager = Manager() @@ -146,7 +151,9 @@ def get_moment(cls, rv, size, *rv_inputs): return mean -def preprocess_xy(X, Y): +def preprocess_xy( + X: TensorLike, Y: TensorLike +) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_]]: if isinstance(Y, (Series, DataFrame)): Y = Y.to_numpy() if isinstance(X, (Series, DataFrame)): diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index b1a22bc..406a97e 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -2,16 +2,29 @@ import warnings +import pytensor.tensor as pt import arviz as az import matplotlib.pyplot as plt import numpy as np +import numpy.typing as npt from pytensor.tensor.var import Variable from scipy.interpolate import griddata from scipy.signal import savgol_filter from scipy.stats import norm, pearsonr +from typing import List, Optional, Tuple, Union +from .tree import Tree -def _sample_posterior(all_trees, X, rng, size=None, excluded=None): +TensorLike = Union[npt.NDArray[np.float_], pt.TensorVariable] + + +def _sample_posterior( + all_trees: List[List[Tree]], + X: TensorLike, + rng: np.random.Generator, + size=Optional[Union[int, Tuple[int, ...]]], + excluded: Optional[List[int]] = None, +) -> npt.NDArray[np.float_]: """ Generate samples from the BART-posterior. From 372b581fd1c60ec71b4e399f23e79f6b30c37bee Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Tue, 28 Mar 2023 21:41:33 +0200 Subject: [PATCH 11/26] fix lint --- pymc_bart/bart.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index 5b18921..26f7722 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -29,7 +29,6 @@ from .utils import _sample_posterior, TensorLike - __all__ = ["BART"] From e7574774171bc2e77484d93e0b33cbda570f4849 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Tue, 28 Mar 2023 22:06:43 +0200 Subject: [PATCH 12/26] fix some hints --- pymc_bart/utils.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index 406a97e..579f96e 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -2,11 +2,11 @@ import warnings -import pytensor.tensor as pt import arviz as az import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt +import pytensor.tensor as pt from pytensor.tensor.var import Variable from scipy.interpolate import griddata from scipy.signal import savgol_filter @@ -22,7 +22,7 @@ def _sample_posterior( all_trees: List[List[Tree]], X: TensorLike, rng: np.random.Generator, - size=Optional[Union[int, Tuple[int, ...]]], + size: Optional[Union[int, Tuple[int, ...]]] = None, excluded: Optional[List[int]] = None, ) -> npt.NDArray[np.float_]: """ @@ -46,12 +46,14 @@ def _sample_posterior( X = X.eval() if size is None: - size = () + size_iter: Union[List, Tuple] = () elif isinstance(size, int): - size = [size] + size_iter = [size] + else: + size_iter = size flatten_size = 1 - for s in size: + for s in size_iter: flatten_size *= s idx = rng.integers(0, len(stacked_trees), size=flatten_size) @@ -62,7 +64,7 @@ def _sample_posterior( for ind, p in enumerate(pred): for tree in stacked_trees[idx[ind]]: p += np.vstack([tree.predict(x, excluded) for x in X]) - pred.reshape((*size, shape, -1)) + pred.reshape((*size_iter, shape, -1)) return pred From 071c8ccda4230fc491481cfc2d1f4cb2591b775c Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Tue, 28 Mar 2023 22:14:19 +0200 Subject: [PATCH 13/26] fix sort --- pymc_bart/bart.py | 6 ++---- pymc_bart/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index 26f7722..e1fe77c 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -15,6 +15,7 @@ # limitations under the License. from multiprocessing import Manager +from typing import List, Optional, Tuple import numpy as np import numpy.typing as npt @@ -23,11 +24,8 @@ from pymc.distributions.distribution import Distribution, _moment from pymc.logprob.abstract import _logprob from pytensor.tensor.random.op import RandomVariable -from typing import List, Optional, Tuple - - -from .utils import _sample_posterior, TensorLike +from .utils import TensorLike, _sample_posterior __all__ = ["BART"] diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index 579f96e..fbb764a 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -1,6 +1,7 @@ """Utility function for variable selection and bart interpretability.""" import warnings +from typing import List, Optional, Tuple, Union import arviz as az import matplotlib.pyplot as plt @@ -11,9 +12,8 @@ from scipy.interpolate import griddata from scipy.signal import savgol_filter from scipy.stats import norm, pearsonr -from typing import List, Optional, Tuple, Union -from .tree import Tree +from .tree import Tree TensorLike = Union[npt.NDArray[np.float_], pt.TensorVariable] From 8d4fc232cd415b0aa9384d6c5bd506fd2c94518e Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Thu, 30 Mar 2023 21:18:43 +0200 Subject: [PATCH 14/26] remove outout None type --- pymc_bart/tree.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index ab6233c..e5e41d8 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -110,7 +110,7 @@ def __init__( self, tree_structure: Dict[int, Node], idx_leaf_nodes: Optional[npt.NDArray[np.int_]], - output: Optional[npt.NDArray[np.float_]], + output: npt.NDArray[np.float_], ) -> None: self.tree_structure = tree_structure self.idx_leaf_nodes = idx_leaf_nodes @@ -173,24 +173,21 @@ def trim(self) -> "Tree": k: Node(v.index, v.value, None, v.idx_split_variable) for k, v in self.tree_structure.items() } - return Tree(tree, None, None) + return Tree(tree_structure=tree, idx_leaf_nodes=None, output=np.array([-1])) def get_split_variables(self) -> Generator[int, None, None]: for node in self.tree_structure.values(): if node.is_split_node(): yield node.idx_split_variable - def _predict(self) -> Optional[npt.NDArray[np.float_]]: + def _predict(self) -> npt.NDArray[np.float_]: output = self.output - if output is None: - return None - if self.idx_leaf_nodes is not None: for node_index in self.idx_leaf_nodes: leaf_node = self.get_node(node_index) output[leaf_node.idx_data_points] = leaf_node.value - return output.T if output is not None else None + return output.T def predict(self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = None) -> float: """ @@ -244,7 +241,7 @@ def _traverse_tree( next_node = current_node.get_idx_left_child() else: next_node = current_node.get_idx_right_child() - return self._traverse_tree(x, next_node, excluded) + return self._traverse_tree(x=x, node_index=next_node, excluded=excluded) def _traverse_leaf_values(self, leaf_values: List[float], node_index: int) -> None: """ From 2a252a1cc879ccf7d18a9dadc2a9fd8c1dad4276 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Thu, 30 Mar 2023 21:25:17 +0200 Subject: [PATCH 15/26] fix predict output type --- pymc_bart/tree.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index e5e41d8..d19e15b 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -189,7 +189,9 @@ def _predict(self) -> npt.NDArray[np.float_]: output[leaf_node.idx_data_points] = leaf_node.value return output.T - def predict(self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = None) -> float: + def predict( + self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = None + ) -> npt.NDArray[np.float_]: """ Predict output of tree for an (un)observed point x. @@ -202,7 +204,7 @@ def predict(self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = Non Returns ------- - float + npt.NDArray[np.float_] Value of the leaf value where the unobserved point lies. """ if excluded is None: @@ -211,7 +213,7 @@ def predict(self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = Non def _traverse_tree( self, x: npt.NDArray[np.float_], node_index: int, excluded: Optional[List[int]] = None - ) -> float: + ) -> npt.NDArray[np.float_]: """ Traverse the tree starting from a particular node given an unobserved point. @@ -226,11 +228,12 @@ def _traverse_tree( Returns ------- - Leaf node value or mean of leaf node values + npt.NDArray[np.float_] + Leaf node value or mean of leaf node values """ current_node: Node = self.get_node(node_index) if current_node.is_leaf_node(): - return current_node.value + return np.array(current_node.value) if excluded is not None and current_node.idx_split_variable in excluded: leaf_values: List[float] = [] From ccd98b8682be9fa7dc7675dda49b7e9327bc449a Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Thu, 30 Mar 2023 21:49:28 +0200 Subject: [PATCH 16/26] fix module path to make it compatible with pymc5.2.0 --- tests/test_bart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_bart.py b/tests/test_bart.py index 6e47b99..d8a3a27 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -4,7 +4,7 @@ from numpy.random import RandomState from numpy.testing import assert_almost_equal, assert_array_equal from pymc.initial_point import make_initial_point_fn -from pymc.logprob.joint_logprob import joint_logp +from pymc.logprob.basic import joint_logp import pymc_bart as pmb From e6d91b1e391c8c3edee1fb4530144863b0374062 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Thu, 30 Mar 2023 23:15:47 +0200 Subject: [PATCH 17/26] type hints to plots --- pymc_bart/bart.py | 3 +- pymc_bart/tree.py | 13 +++-- pymc_bart/utils.py | 134 +++++++++++++++++++++++++-------------------- 3 files changed, 84 insertions(+), 66 deletions(-) diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index e1fe77c..b7f688f 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -25,6 +25,7 @@ from pymc.logprob.abstract import _logprob from pytensor.tensor.random.op import RandomVariable +from .tree import Tree from .utils import TensorLike, _sample_posterior __all__ = ["BART"] @@ -38,7 +39,7 @@ class BARTRV(RandomVariable): ndims_params: List[int] = [2, 1, 0, 0, 1] dtype: str = "floatX" _print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}") - all_trees = None + all_trees = List[List[Tree]] def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): return dist_params[0].shape[:1] diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index d19e15b..4daa69d 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -190,7 +190,7 @@ def _predict(self) -> npt.NDArray[np.float_]: return output.T def predict( - self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = None + self, x: npt.NDArray[np.float_], excluded: Optional[npt.NDArray[np.int_]] = None ) -> npt.NDArray[np.float_]: """ Predict output of tree for an (un)observed point x. @@ -199,7 +199,7 @@ def predict( ---------- x : npt.NDArray[np.float_] Unobserved point - excluded: Optional[List[int]] + excluded: Optional[npt.NDArray[np.int_]] Indexes of the variables to exclude when computing predictions Returns @@ -208,11 +208,14 @@ def predict( Value of the leaf value where the unobserved point lies. """ if excluded is None: - excluded = [] + excluded = np.array([]) return self._traverse_tree(x, 0, excluded) def _traverse_tree( - self, x: npt.NDArray[np.float_], node_index: int, excluded: Optional[List[int]] = None + self, + x: npt.NDArray[np.float_], + node_index: int, + excluded: Optional[npt.NDArray[np.int_]] = None, ) -> npt.NDArray[np.float_]: """ Traverse the tree starting from a particular node given an unobserved point. @@ -223,7 +226,7 @@ def _traverse_tree( Unobserved point node_index : int Index of the node to start the traversal from - excluded: Optional[List[int]] + excluded: Optional[npt.NDArray[np.int_]] Indexes of the variables to exclude when computing predictions Returns diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index fbb764a..fb67e55 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -1,7 +1,7 @@ """Utility function for variable selection and bart interpretability.""" import warnings -from typing import List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import arviz as az import matplotlib.pyplot as plt @@ -23,7 +23,7 @@ def _sample_posterior( X: TensorLike, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None, - excluded: Optional[List[int]] = None, + excluded: Optional[npt.NDArray[np.int_]] = None, ) -> npt.NDArray[np.float_]: """ Generate samples from the BART-posterior. @@ -32,13 +32,13 @@ def _sample_posterior( ---------- all_trees : list List of all trees sampled from a posterior - X : array-like + X : tensor-like A covariate matrix. Use the same used to fit BART for in-sample predictions or a new one for out-of-sample predictions. rng : NumPy RandomGenerator size : int or tuple Number of samples. - excluded : list + excluded : Optional[npt.NDArray[np.int_]] Indexes of the variables to exclude when computing predictions """ stacked_trees = all_trees @@ -68,7 +68,13 @@ def _sample_posterior( return pred -def plot_convergence(idata, var_name=None, kind="ecdf", figsize=None, ax=None): +def plot_convergence( + idata: az.InferenceData, + var_name: Optional[str] = None, + kind: str = "ecdf", + figsize=Tuple[float, float], + ax=None, +) -> List[plt.Axes]: """ Plot convergence diagnostics. @@ -76,20 +82,20 @@ def plot_convergence(idata, var_name=None, kind="ecdf", figsize=None, ax=None): ---------- idata : InferenceData InferenceData object containing the posterior samples. - var_name : str + var_name : Optional[str] Name of the BART variable to plot. Defaults to None. kind : str Type of plot to display. Options are "ecdf" (default) and "kde". - figsize : tuple + figsize : Tuple[float, float] Figure size. Defaults to None. ax : matplotlib axes Axes on which to plot. Defaults to None. Returns ------- - ax : matplotlib axes + List[ax] : matplotlib axes """ - ess_threshold = idata.posterior.chain.size * 100 + ess_threshold = idata["posterior"]["chain"].size * 100 ess = np.atleast_2d(az.ess(idata, method="bulk", var_names=var_name)[var_name].values) rhat = np.atleast_2d(az.rhat(idata, var_names=var_name)[var_name].values) @@ -97,7 +103,7 @@ def plot_convergence(idata, var_name=None, kind="ecdf", figsize=None, ax=None): figsize = (10, 3) if kind == "ecdf": - kind_func = az.plot_ecdf + kind_func: Callable[..., Any] = az.plot_ecdf sharey = True elif kind == "kde": kind_func = az.plot_kde @@ -125,29 +131,28 @@ def plot_convergence(idata, var_name=None, kind="ecdf", figsize=None, ax=None): def plot_dependence( - bartrv, - X, - Y=None, - kind="pdp", - xs_interval="linear", - xs_values=None, - var_idx=None, - var_discrete=None, - func=None, - samples=50, - instances=10, - random_seed=None, - sharey=True, - rug=True, - smooth=True, - indices=None, - grid="long", + bartrv: Variable, + X: npt.NDArray[np.float_], + Y: Optional[npt.NDArray[np.float_]] = None, + kind: str = "pdp", + xs_interval: str = "linear", + xs_values: Optional[Union[int, List[float]]] = None, + var_idx: Optional[List[int]] = None, + var_discrete: Optional[List[int]] = None, + func: Optional[Callable] = None, + samples: int = 50, + instances: int = 10, + random_seed: Optional[int] = None, + sharey: bool = True, + rug: bool = True, + smooth: bool = True, + grid: str = "long", color="C0", - color_mean="C0", - alpha=0.1, - figsize=None, - smooth_kwargs=None, - ax=None, + color_mean: str = "C0", + alpha: float = 0.1, + figsize: Optional[Tuple[float, float]] = None, + smooth_kwargs: Optional[Dict[str, Any]] = None, + ax: Optional[plt.Axes] = None, ): """ Partial dependence or individual conditional expectation plot. @@ -156,9 +161,9 @@ def plot_dependence( ---------- bartrv : BART Random Variable BART variable once the model that include it has been fitted. - X : array-like + X : npt.NDArray[np.float_] The covariate matrix. - Y : array-like + Y : Optional[npt.NDArray[np.float_]], by default None. The response vector. kind : str Whether to plor a partial dependence plot ("pdp") or an individual conditional expectation @@ -168,22 +173,22 @@ def plot_dependence( evenly spaced values in the range of X. "quantiles", the evaluation is done at the specified quantiles of X. "insample", the evaluation is done at the values of X. For discrete variables these options are ommited. - xs_values : int or list + xs_values : Optional[Union[int, List[float]]], by default None. Values of X used to evaluate the predicted function. If ``xs_interval="linear"`` number of points in the evenly spaced grid. If ``xs_interval="quantiles"``quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive. Ignored when ``xs_interval="insample"``. - var_idx : list + var_idx : Optional[List[int]], by default None. List of the indices of the covariate for which to compute the pdp or ice. - var_discrete : list + var_discrete : Optional[List[int]], by default None. List of the indices of the covariate treated as discrete. - func : function + func : Optional[Callable], by default None. Arbitrary function to apply to the predictions. Defaults to the identity function. samples : int Number of posterior samples used in the predictions. Defaults to 50 instances : int Number of instances of X to plot. Only relevant if ice ``kind="ice"`` plots. - random_seed : int + random_seed : Optional[int], by default None. Seed used to sample from the posterior. Defaults to None. sharey : bool Controls sharing of properties among y-axes. Defaults to True. @@ -233,7 +238,7 @@ def plot_dependence( else: x_names = [] - if hasattr(Y, "name"): + if Y is not None and hasattr(Y, "name"): y_label = f"Predicted {Y.name}" else: y_label = "Predicted Y" @@ -261,7 +266,7 @@ def plot_dependence( xs_values = [0.05, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.95] if kind == "ice": - instances = rng.choice(range(X.shape[0]), replace=False, size=instances) + instances_ary = rng.choice(range(X.shape[0]), replace=False, size=instances) new_y = [] new_x_target = [] @@ -278,9 +283,9 @@ def plot_dependence( if i in var_discrete: new_x_i = np.unique(X[:, i]) else: - if xs_interval == "linear": + if xs_interval == "linear" and isinstance(xs_values, int): new_x_i = np.linspace(np.nanmin(X[:, i]), np.nanmax(X[:, i]), xs_values) - elif xs_interval == "quantiles": + elif xs_interval == "quantiles" and isinstance(xs_values, list): new_x_i = np.quantile(X[:, i], q=xs_values) elif xs_interval == "insample": new_x_i = X[:, i] @@ -293,7 +298,7 @@ def plot_dependence( ) new_x_target.append(new_x_i) else: - for instance in instances: + for instance in instances_ary: new_X = X[idx_s] new_X[:, indices_mi] = X[:, indices_mi][instance] y_pred.append( @@ -306,9 +311,9 @@ def plot_dependence( if func is not None: new_y = [func(nyi) for nyi in new_y] shape = 1 - new_y = np.array(new_y) - if new_y[0].ndim == 3: - shape = new_y[0].shape[0] + new_y_ary = np.array(new_y) + if new_y_ary[0].ndim == 3: + shape = new_y_ary[0].shape[0] if ax is None: if grid == "long": fig, axes = plt.subplots(len(var_idx) * shape, sharey=sharey, figsize=figsize) @@ -317,7 +322,7 @@ def plot_dependence( elif isinstance(grid, tuple): fig, axes = plt.subplots(grid[0], grid[1], sharey=sharey, figsize=figsize) grid_size = grid[0] * grid[1] - n_plots = new_y.squeeze().shape[0] + n_plots = new_y_ary.squeeze().shape[0] if n_plots > grid_size: warnings.warn("The grid is smaller than the number of available variables to plot") elif n_plots < grid_size: @@ -332,7 +337,7 @@ def plot_dependence( x_idx = 0 y_idx = 0 for ax in axes: # pylint: disable=redefined-argument-from-local - nyi = new_y[x_idx][y_idx] + nyi = new_y_ary[x_idx][y_idx] nxi = new_x_target[x_idx] var = var_idx[x_idx] @@ -402,8 +407,15 @@ def plot_dependence( def plot_variable_importance( - idata, bartrv, X, labels=None, sort_vars=True, figsize=None, samples=100, random_seed=None -): + idata: az.InferenceData, + bartrv: Variable, + X: npt.NDArray[np.float_], + labels: Optional[List[str]] = None, + sort_vars: bool = True, + figsize: Optional[Tuple[float, float]] = None, + samples: int = 100, + random_seed: Optional[int] = None, +) -> Tuple[npt.NDArray[np.int_], List[plt.axes]]: """ Estimates variable importance from the BART-posterior. @@ -413,9 +425,9 @@ def plot_variable_importance( InferenceData containing a collection of BART_trees in sample_stats group bartrv : BART Random Variable BART variable once the model that include it has been fitted. - X : array-like + X : npt.NDArray[np.float_] The covariate matrix. - labels : list + labels : Optional[List[str]] List of the names of the covariates. If X is a DataFrame the names of the covariables will be taken from it and this argument will be ignored. sort_vars : bool @@ -424,7 +436,7 @@ def plot_variable_importance( Figure size. If None it will be defined automatically. samples : int Number of predictions used to compute correlation for subsets of variables. Defaults to 100 - random_seed : int + random_seed : Optional[int] random_seed used to sample from the posterior. Defaults to None. Returns ------- @@ -437,18 +449,18 @@ def plot_variable_importance( labels = X.columns X = X.values - var_imp = idata.sample_stats["variable_inclusion"].mean(("chain", "draw")).values + var_imp = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values if labels is None: - labels = np.arange(len(var_imp)) + labels_ary = np.arange(len(var_imp)) else: - labels = np.array(labels) + labels_ary = np.array(labels) rng = np.random.default_rng(random_seed) ticks = np.arange(len(var_imp), dtype=int) idxs = np.argsort(var_imp) subsets = [idxs[:-i] for i in range(1, len(idxs))] - subsets.append(None) + subsets.append(None) # type: ignore if sort_vars: indices = idxs[::-1] @@ -456,7 +468,7 @@ def plot_variable_importance( indices = np.arange(len(var_imp)) axes[0].plot((var_imp / var_imp.sum())[indices], "o-") axes[0].set_xticks(ticks) - axes[0].set_xticklabels(labels[indices]) + axes[0].set_xticklabels(labels_ary[indices]) axes[0].set_xlabel("covariables") axes[0].set_ylabel("importance") @@ -467,7 +479,9 @@ def plot_variable_importance( ev_mean = np.zeros(len(var_imp)) ev_hdi = np.zeros((len(var_imp), 2)) for idx, subset in enumerate(subsets): - predicted_subset = _sample_posterior(all_trees, X=X, rng=rng, size=samples, excluded=subset) + predicted_subset = _sample_posterior( + all_trees=all_trees, X=X, rng=rng, size=samples, excluded=subset + ) pearson = np.zeros(samples) for j in range(samples): pearson[j] = ( From 0b2c6269c2ec140cdd067afc68f43eaa836d2d82 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Fri, 31 Mar 2023 09:48:42 +0200 Subject: [PATCH 18/26] unpin mypy specifict version --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index cc34345..e51a93c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,6 @@ black==22.3.0 click==8.0.4 -mypy==1.1.1 +mypy>=1.1.1 pandas-stubs==1.5.3.230304 pre-commit pylint==2.10.2 From b8e75fab6789a0c6aa9c001e6d7ca3e73c1328e5 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Fri, 31 Mar 2023 16:35:29 +0200 Subject: [PATCH 19/26] Update pymc_bart/utils.py Co-authored-by: Osvaldo A Martin --- pymc_bart/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index fb67e55..f5adfce 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -72,7 +72,7 @@ def plot_convergence( idata: az.InferenceData, var_name: Optional[str] = None, kind: str = "ecdf", - figsize=Tuple[float, float], + figsize: Tuple[float, float] = None, ax=None, ) -> List[plt.Axes]: """ From b45474479866f14ac4a861f79a39c6631a1f44d8 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Fri, 31 Mar 2023 16:36:26 +0200 Subject: [PATCH 20/26] Update pymc_bart/pgbart.py Co-authored-by: Osvaldo A Martin --- pymc_bart/pgbart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index a993e19..288ccfe 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -415,7 +415,7 @@ def grow_tree( current_node.get_idx_right_child(), ) - new_nodes = np.array([]) + new_nodes = np.empty(2, dtype=object) for idx in range(2): idx_data_point = new_idx_data_points[idx] node_value = draw_leaf_value( From fe5331f08c0c48a8c5a6062b3a363a43eb087f74 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Fri, 31 Mar 2023 16:36:48 +0200 Subject: [PATCH 21/26] Update pymc_bart/pgbart.py Co-authored-by: Osvaldo A Martin --- pymc_bart/pgbart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 288ccfe..c093d12 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -430,7 +430,7 @@ def grow_tree( value=node_value, idx_data_points=idx_data_point, ) - new_nodes = np.append(new_nodes, new_node) + new_nodes[idx] = new_node tree.grow_leaf_node(current_node, selected_predictor, split_value, index_leaf_node) tree.set_node(new_nodes[0].index, new_nodes[0]) From b64d69f5a35afc72de999b91356f9bc434558115 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Fri, 31 Mar 2023 16:37:32 +0200 Subject: [PATCH 22/26] Update pymc_bart/tree.py Co-authored-by: Osvaldo A Martin --- pymc_bart/tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 4daa69d..de6e824 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -130,7 +130,7 @@ def new_tree( index=0, value=leaf_node_value, idx_data_points=idx_data_points ) }, - idx_leaf_nodes=np.array([0]), + idx_leaf_nodes=[0], output=np.zeros((num_observations, shape)).astype(config.floatX).squeeze(), ) From 4e1996a6a0bf45474bc5bd0fa97efad98595b922 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Fri, 31 Mar 2023 16:48:05 +0200 Subject: [PATCH 23/26] Update pymc_bart/tree.py Co-authored-by: Osvaldo A Martin --- pymc_bart/tree.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index de6e824..230bf41 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -155,8 +155,7 @@ def get_node(self, index: int) -> Node: def set_node(self, index: int, node: Node) -> None: self.tree_structure[index] = node if node.is_leaf_node() and self.idx_leaf_nodes is not None: - # self.idx_leaf_nodes.append(index) - self.idx_leaf_nodes = np.append(self.idx_leaf_nodes, index) + self.idx_leaf_nodes.append(index) def grow_leaf_node( self, current_node: Node, selected_predictor: int, split_value: float, index_leaf_node: int From b17d27e27292cf969890163075d30c9148674b33 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Fri, 31 Mar 2023 17:14:18 +0200 Subject: [PATCH 24/26] make idx_leaf_nodes list instead of numpy array --- pymc_bart/tree.py | 16 +++++++--------- pymc_bart/utils.py | 2 +- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 230bf41..6ce7399 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -88,29 +88,29 @@ class Tree: The dictionary's keys are integers that represent the nodes position. The dictionary's values are objects of type Node that represent the split and leaf nodes of the tree itself. - idx_leaf_nodes : Optional[npt.NDArray[np.int_]] - Array with the index of the leaf nodes of the tree. output: Optional[npt.NDArray[np.float_]] Array of shape number of observations, shape + idx_leaf_nodes : Optional[List[int]], by default None. + Array with the index of the leaf nodes of the tree. Parameters ---------- tree_structure : Dictionary of nodes - idx_leaf_nodes : Array with the index of the leaf nodes of the tree. output : Array of shape number of observations, shape + idx_leaf_nodes : List with the index of the leaf nodes of the tree. """ __slots__ = ( "tree_structure", - "idx_leaf_nodes", "output", + "idx_leaf_nodes", ) def __init__( self, tree_structure: Dict[int, Node], - idx_leaf_nodes: Optional[npt.NDArray[np.int_]], output: npt.NDArray[np.float_], + idx_leaf_nodes: Optional[List[int]] = None, ) -> None: self.tree_structure = tree_structure self.idx_leaf_nodes = idx_leaf_nodes @@ -146,8 +146,7 @@ def copy(self) -> "Tree": for k, v in self.tree_structure.items() } idx_leaf_nodes = self.idx_leaf_nodes.copy() if self.idx_leaf_nodes is not None else None - output = self.output.copy() if self.output is not None else None - return Tree(tree_structure=tree, idx_leaf_nodes=idx_leaf_nodes, output=output) + return Tree(tree_structure=tree, idx_leaf_nodes=idx_leaf_nodes, output=self.output) def get_node(self, index: int) -> Node: return self.tree_structure[index] @@ -164,8 +163,7 @@ def grow_leaf_node( current_node.idx_split_variable = selected_predictor current_node.idx_data_points = None if self.idx_leaf_nodes is not None: - # self.idx_leaf_nodes.remove(index_leaf_node) - self.idx_leaf_nodes = np.setdiff1d(self.idx_leaf_nodes, index_leaf_node) + self.idx_leaf_nodes.remove(index_leaf_node) def trim(self) -> "Tree": tree: Dict[int, Node] = { diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index f5adfce..fb67e55 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -72,7 +72,7 @@ def plot_convergence( idata: az.InferenceData, var_name: Optional[str] = None, kind: str = "ecdf", - figsize: Tuple[float, float] = None, + figsize=Tuple[float, float], ax=None, ) -> List[plt.Axes]: """ From 71dd2a1637c9d7c593ff974d5c4fe64e2d89b5a4 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Fri, 31 Mar 2023 17:18:29 +0200 Subject: [PATCH 25/26] fix plot figsize type --- pymc_bart/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index fb67e55..45b9a75 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -72,7 +72,7 @@ def plot_convergence( idata: az.InferenceData, var_name: Optional[str] = None, kind: str = "ecdf", - figsize=Tuple[float, float], + figsize: Optional[Tuple[float, float]] = None, ax=None, ) -> List[plt.Axes]: """ @@ -86,7 +86,7 @@ def plot_convergence( Name of the BART variable to plot. Defaults to None. kind : str Type of plot to display. Options are "ecdf" (default) and "kde". - figsize : Tuple[float, float] + figsize : Optional[Tuple[float, float]], by default None. Figure size. Defaults to None. ax : matplotlib axes Axes on which to plot. Defaults to None. From 407e4a6d2619c7d71d4c250ab600db61b2b9e071 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Fri, 31 Mar 2023 17:22:54 +0200 Subject: [PATCH 26/26] fix exclude type --- pymc_bart/tree.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 6ce7399..05f4301 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -187,7 +187,7 @@ def _predict(self) -> npt.NDArray[np.float_]: return output.T def predict( - self, x: npt.NDArray[np.float_], excluded: Optional[npt.NDArray[np.int_]] = None + self, x: npt.NDArray[np.float_], excluded: Optional[List[int]] = None ) -> npt.NDArray[np.float_]: """ Predict output of tree for an (un)observed point x. @@ -196,7 +196,7 @@ def predict( ---------- x : npt.NDArray[np.float_] Unobserved point - excluded: Optional[npt.NDArray[np.int_]] + excluded: Optional[List[int]] Indexes of the variables to exclude when computing predictions Returns @@ -205,14 +205,14 @@ def predict( Value of the leaf value where the unobserved point lies. """ if excluded is None: - excluded = np.array([]) + excluded = [] return self._traverse_tree(x, 0, excluded) def _traverse_tree( self, x: npt.NDArray[np.float_], node_index: int, - excluded: Optional[npt.NDArray[np.int_]] = None, + excluded: Optional[List[int]] = None, ) -> npt.NDArray[np.float_]: """ Traverse the tree starting from a particular node given an unobserved point. @@ -223,7 +223,7 @@ def _traverse_tree( Unobserved point node_index : int Index of the node to start the traversal from - excluded: Optional[npt.NDArray[np.int_]] + excluded: Optional[List[int]] Indexes of the variables to exclude when computing predictions Returns