diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index 9fd2d6d..847c89a 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -9,7 +9,7 @@ assignees: '' ## Tell us about it -The more specific the better. +The more specific the better. ## Thoughts on implementation diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fea66fd..6c2c191 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -36,17 +36,6 @@ jobs: python --version conda list pip freeze - - name: Run linters - shell: bash -l {0} - run: | - python -m black pymc_bart --check - echo "Success!" - echo "Checking code style with pylint..." - python -m pylint pymc_bart/ - - name: Run Mypy - shell: bash -l {0} - run: | - python -m mypy pymc_bart - name: Run tests shell: bash -l {0} run: | diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5432a77..0962c42 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,24 +1,30 @@ +ci: + autofix_commit_msg: | + [pre-commit.ci] auto fixes from pre-commit.com hooks + + for more information, see https://pre-commit.ci + autofix_prs: true + autoupdate_branch: "" + autoupdate_commit_msg: "[pre-commit.ci] pre-commit autoupdate" + autoupdate_schedule: weekly + skip: [] + submodules: false + repos: - - repo: local + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.14 + hooks: + - id: ruff + args: ["--fix", "--show-source"] + - id: ruff-format + args: ["--line-length=100"] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.8.0 hooks: - - id: black - name: black - entry: black - language: system - types: [python] - files: ^pymc_bart/ - - id: pylint - name: pylint - entry: pylint --rcfile=.pylintrc - language: system - types: [python] - files: ^pymc_bart/ - id: mypy - name: mypy - entry: mypy - language: system - types: [python] + args: [--ignore-missing-imports] files: ^pymc_bart/ + additional_dependencies: [numpy<1.25.0, pandas-stubs] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v3.2.0 hooks: diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index 305bbb0..0000000 --- a/.pylintrc +++ /dev/null @@ -1,494 +0,0 @@ -[MASTER] - -# A comma-separated list of package or module names from where C extensions may -# be loaded. Extensions are loading into the active Python interpreter and may -# run arbitrary code -extension-pkg-whitelist= - -# Add files or directories to the blacklist. They should be base names, not -# paths. -ignore=CVS,tests - -# Add files or directories matching the regex patterns to the blacklist. The -# regex matches against base names, not paths. -ignore-patterns= - -# Python code to execute, usually for sys.path manipulation such as -# pygtk.require(). -#init-hook= - -# Use multiple processes to speed up Pylint. -jobs=1 - -# List of plugins (as comma separated values of python modules names) to load, -# usually to register additional checkers. -load-plugins= - -# Pickle collected data for later comparisons. -persistent=yes - -# Specify a configuration file. -#rcfile= - -# When enabled, pylint would attempt to guess common misconfiguration and emit -# user-friendly hints instead of false-positive error messages -suggestion-mode=yes - -# Allow loading of arbitrary C extensions. Extensions are imported into the -# active Python interpreter and may run arbitrary code. -unsafe-load-any-extension=no - - -[MESSAGES CONTROL] - -# Only show warnings with the listed confidence levels. Leave empty to show -# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED -confidence= - -# Disable the message, report, category or checker with the given id(s). You -# can either give multiple identifiers separated by comma (,) or put this -# option multiple times (only on the command line, not in the configuration -# file where it should appear only once).You can also use "--disable=all" to -# disable everything first and then reenable specific checks. For example, if -# you want to run only the similarities checker, you can use "--disable=all -# --enable=similarities". If you want to run only the classes checker, but have -# no Warning level messages displayed, use"--disable=all --enable=classes -# --disable=W" -disable=missing-docstring, - no-else-return, - no-member, - len-as-condition, - too-many-arguments, - too-many-locals, - too-many-branches, - too-many-statements, - too-many-instance-attributes, - too-few-public-methods, - import-error, - protected-access - - - -# Enable the message, report, category or checker with the given id(s). You can -# either give multiple identifier separated by comma (,) or put this option -# multiple time (only on the command line, not in the configuration file where -# it should appear only once). See also the "--disable" option for examples. -enable=c-extension-no-member - - -[REPORTS] - -# Python expression which should return a note less than 10 (10 is the highest -# note). You have access to the variables errors warning, statement which -# respectively contain the number of errors / warnings messages and the total -# number of statements analyzed. This is used by the global evaluation report -# (RP0004). -evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) - -# Template used to display messages. This is a python new-style format string -# used to format the message information. See doc for all details -#msg-template= - -# Set the output format. Available formats are text, parseable, colorized, json -# and msvs (visual studio).You can also give a reporter class, eg -# mypackage.mymodule.MyReporterClass. -output-format=text - -# Tells whether to display a full report or only the messages -reports=no - -# Activate the evaluation score. -score=yes - - -[REFACTORING] - -# Maximum number of nested blocks for function / method body -max-nested-blocks=5 - -# Complete name of functions that never returns. When checking for -# inconsistent-return-statements if a never returning function is called then -# it will be considered as an explicit return statement and no message will be -# printed. -never-returning-functions=optparse.Values,sys.exit - - -[FORMAT] - -# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. -expected-line-ending-format= - -# Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=^\s*(# )??$ - -# Number of spaces of indent required inside a hanging or continued line. -indent-after-paren=4 - -# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 -# tab). -indent-string=' ' - -# Maximum number of characters on a single line. -max-line-length=100 - -# Maximum number of lines in a module -max-module-lines=1000 - -# Allow the body of a class to be on the same line as the declaration if body -# contains single statement. -single-line-class-stmt=no - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=no - - -[VARIABLES] - -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid to define new builtins when possible. -additional-builtins= - -# Tells whether unused global variables should be treated as a violation. -allow-global-unused-variables=yes - -# List of strings which can identify a callback function by name. A callback -# name must start or end with one of those strings. -callbacks=cb_, - _cb - -# A regular expression matching the name of dummy variables (i.e. expectedly -# not used). -dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ - -# Argument names that match this expression will be ignored. Default to name -# with leading underscore -ignored-argument-names=_.*|^ignored_|^unused_ - -# Tells whether we should check for unused import in __init__ files. -init-import=no - -# List of qualified module names which can have objects that can redefine -# builtins. -redefining-builtins-modules=past.builtins,future.builtins - - -[BASIC] - -# Naming style matching correct argument names -argument-naming-style=snake_case - -# Regular expression matching correct argument names. Overrides argument- -# naming-style -#argument-rgx= - -# Naming style matching correct attribute names -attr-naming-style=snake_case - -# Regular expression matching correct attribute names. Overrides attr-naming- -# style -#attr-rgx= - -# Bad variable names which should always be refused, separated by a comma -bad-names=foo, - bar, - baz, - toto, - tutu, - tata - -# Naming style matching correct class attribute names -class-attribute-naming-style=any - -# Regular expression matching correct class attribute names. Overrides class- -# attribute-naming-style -#class-attribute-rgx= - -# Naming style matching correct class names -class-naming-style=PascalCase - -# Regular expression matching correct class names. Overrides class-naming-style -#class-rgx= - -# Naming style matching correct constant names -const-naming-style=UPPER_CASE - -# Regular expression matching correct constant names. Overrides const-naming- -# style -#const-rgx= - -# Minimum line length for functions/classes that require docstrings, shorter -# ones are exempt. -docstring-min-length=-1 - -# Naming style matching correct function names -function-naming-style=snake_case - -# Regular expression matching correct function names. Overrides function- -# naming-style -#function-rgx= - -# Good variable names which should always be accepted, separated by a comma -good-names=i, - j, - m, - p, - k, - s, - x, - y, - X, - Y, - ax, - op, - pt, - p0, - p1, - rv, - fake_X, - new_X, - new_y, - a, - b, - n, - - -# Include a hint for the correct naming format with invalid-name -include-naming-hint=no - -# Naming style matching correct inline iteration names -inlinevar-naming-style=any - -# Regular expression matching correct inline iteration names. Overrides -# inlinevar-naming-style -#inlinevar-rgx= - -# Naming style matching correct method names -method-naming-style=snake_case - -# Regular expression matching correct method names. Overrides method-naming- -# style -#method-rgx= - -# Naming style matching correct module names -module-naming-style=snake_case - -# Regular expression matching correct module names. Overrides module-naming- -# style -#module-rgx= - -# Colon-delimited sets of names that determine each other's naming style when -# the name regexes allow several styles. -name-group= - -# Regular expression which should only match function or class names that do -# not require a docstring. -no-docstring-rgx=^_ - -# List of decorators that produce properties, such as abc.abstractproperty. Add -# to this list to register other decorators that produce valid properties. -property-classes=abc.abstractproperty - -# Naming style matching correct variable names -variable-naming-style=snake_case - -# Regular expression matching correct variable names. Overrides variable- -# naming-style -#variable-rgx= - - -[LOGGING] - -# Logging modules to check that the string format arguments are in logging -# function parameter format -logging-modules=logging - - -[SIMILARITIES] - -# Ignore comments when computing similarities. -ignore-comments=yes - -# Ignore docstrings when computing similarities. -ignore-docstrings=yes - -# Ignore imports when computing similarities. -ignore-imports=no - -# Minimum lines number of a similarity. -min-similarity-lines=50 - - -[TYPECHECK] - -# List of decorators that produce context managers, such as -# contextlib.contextmanager. Add to this list to register other decorators that -# produce valid context managers. -contextmanager-decorators=contextlib.contextmanager - -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E1101 when accessed. Python regular -# expressions are accepted. -generated-members= - -# Tells whether missing members accessed in mixin class should be ignored. A -# mixin class is detected if its name ends with "mixin" (case insensitive). -ignore-mixin-members=yes - -# This flag controls whether pylint should warn about no-member and similar -# checks whenever an opaque object is returned when inferring. The inference -# can return multiple potential results while evaluating a Python object, but -# some branches might not be evaluated, which results in partial inference. In -# that case, it might be useful to still emit no-member and other checks for -# the rest of the inferred objects. -ignore-on-opaque-inference=yes - -# List of class names for which member attributes should not be checked (useful -# for classes with dynamically set attributes). This supports the use of -# qualified names. -ignored-classes=optparse.Values,thread._local,_thread._local,netCDF4 - -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis. It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules= - -# Show a hint with possible names when a member name was not found. The aspect -# of finding the hint is based on edit distance. -missing-member-hint=yes - -# The minimum edit distance a name should have in order to be considered a -# similar match for a missing member name. -missing-member-hint-distance=1 - -# The total number of similar names that should be taken in consideration when -# showing a hint for a missing member. -missing-member-max-choices=1 - - -[MISCELLANEOUS] - -# List of note tags to take in consideration, separated by a comma. -notes=FIXME, - XXX, - TODO - - -[SPELLING] - -# Limits count of emitted suggestions for spelling mistakes -max-spelling-suggestions=4 - -# Spelling dictionary name. Available dictionaries: none. To make it working -# install python-enchant package. -spelling-dict= - -# List of comma separated words that should not be checked. -spelling-ignore-words= - -# A path to a file that contains private dictionary; one word per line. -spelling-private-dict-file= - -# Tells whether to store unknown words to indicated private dictionary in -# --spelling-private-dict-file option instead of raising a message. -spelling-store-unknown-words=no - - -[CLASSES] - -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__, - __new__, - setUp - -# List of member names, which should be excluded from the protected access -# warning. -exclude-protected=_asdict, - _fields, - _replace, - _source, - _make, - _get_frozen, - _get_lines, - _update_rv_frozen, - _update, - -# List of valid names for the first argument in a class method. -valid-classmethod-first-arg=cls - -# List of valid names for the first argument in a metaclass class method. -valid-metaclass-classmethod-first-arg=mcs - - -[IMPORTS] - -# Allow wildcard imports from modules that define __all__. -allow-wildcard-with-all=no - -# Analyse import fallback blocks. This can be used to support both Python 2 and -# 3 compatible code, which means that the block might have code that exists -# only in one or another interpreter, leading to false positives when analysed. -analyse-fallback-blocks=no - -# Deprecated modules which should not be used, separated by a comma -deprecated-modules=optparse,tkinter.tix - -# Create a graph of external dependencies in the given file (report RP0402 must -# not be disabled) -ext-import-graph= - -# Create a graph of every (i.e. internal and external) dependencies in the -# given file (report RP0402 must not be disabled) -import-graph= - -# Create a graph of internal dependencies in the given file (report RP0402 must -# not be disabled) -int-import-graph= - -# Force import order to recognize a module as part of the standard -# compatibility libraries. -known-standard-library= - -# Force import order to recognize a module as part of a third party library. -known-third-party=enchant - - -[DESIGN] - -# Maximum number of arguments for function / method -max-args=10 - -# Maximum number of attributes for a class (see R0902). -max-attributes=10 - -# Maximum number of boolean expressions in a if statement -max-bool-expr=5 - -# Maximum number of branch for function / method body -max-branches=12 - -# Maximum number of locals for function / method body -max-locals=15 - -# Maximum number of parents for a class (see R0901). -max-parents=7 - -# Maximum number of public methods for a class (see R0904). -max-public-methods=20 - -# Maximum number of return / yield for function / method body -max-returns=6 - -# Maximum number of statements in function / method body -max-statements=50 - -# Minimum number of public methods for a class (see R0903). -min-public-methods=2 - - -[EXCEPTIONS] - -# Exceptions that will emit a warning when being caught. Defaults to -# "Exception" -overgeneral-exceptions=Exception diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 287d4ca..399b154 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -13,7 +13,70 @@ and including useful supporting information. ## Contributing code Thanks for your interest in contributing code to pymc_bart! -* If this is your first time contributing to a project on GitHub, please read through our step by step guide to contributing to pymc_bart +**If this is your first time contributing to a project on GitHub, please read through our step by step guide to contributing to pymc_bart** + +### Feature Branch + +1. From the fork of the pymc_bart repository, create a new branch for your feature. + +```bash +git checkout -b feature_branch_name +``` + +2. Make your changes to the code base. + +3.Add and commit your changes. + +```bash +git add my_modified_file.py +git commit -m "Added a new feature" +``` + +4. Push your changes to your fork of the pymc_bart repository. + +```bash +git push origin feature_branch_name +``` + +### Code Style + +The repository has some code style checks in place. This will happen on every commit of a pull request. If you want to run the checks locally, you can do so by running the following command from the root of the repository: + +0. Create a virtual environment (optional, but strongly recommended) + +1. Install pre-commit + +```bash +pip install pre-commit +``` + +2. Set up pre-commit + +```bash +pre-commit install +``` + +3. Run the complete pre-commit hook to check specific files: + +```bash +pre-commit run --files pymc_bart/tree.py +``` + +or all files: + +```bash +pre-commit run --all-files +``` + +**Once you commit something the pre-commit hook will run all the checks**! + +You can skip this (for example when is WIP) by adding a flag (`-n` means no-verify) + +```bash +git commit -m"my message" -n +``` + +**Remark:** One can, of course, install `ruff` in the Python environment to enable auto-format (for example in VS Code), but this is not strictly necessary. The specific versions of` ruff` and `mypy` must be only specified in `.pre-commit-config.yaml`. It should be the only source of truth! Hence, if you want to install them locally make sure you use the same versions (revisions `rev` in the config file) as in the config file. ### Adding new features If you are interested in adding a new feature to pymc_bart, diff --git a/docs/conf.py b/docs/conf.py index 9f1b249..236fd76 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,5 @@ -import os, sys +import os +import sys from pathlib import Path # -- Project information ----------------------------------------------------- @@ -42,10 +43,18 @@ if os.path.exists(file): os.remove(file) -os.system("wget https://raw.githubusercontent.com/pymc-devs/pymc-examples/main/examples/bart/bart_introduction.ipynb -P examples") -os.system("wget https://raw.githubusercontent.com/pymc-devs/pymc-examples/main/examples/bart/bart_quantile_regression.ipynb -P examples") -os.system("wget https://raw.githubusercontent.com/pymc-devs/pymc-examples/main/examples/bart/bart_heteroscedasticity.ipynb -P examples") -os.system("wget https://raw.githubusercontent.com/pymc-devs/pymc-examples/main/examples/references.bib -P examples") +os.system( + "wget https://raw.githubusercontent.com/pymc-devs/pymc-examples/main/examples/bart/bart_introduction.ipynb -P examples" +) +os.system( + "wget https://raw.githubusercontent.com/pymc-devs/pymc-examples/main/examples/bart/bart_quantile_regression.ipynb -P examples" +) +os.system( + "wget https://raw.githubusercontent.com/pymc-devs/pymc-examples/main/examples/bart/bart_heteroscedasticity.ipynb -P examples" +) +os.system( + "wget https://raw.githubusercontent.com/pymc-devs/pymc-examples/main/examples/references.bib -P examples" +) # bibtex config bibtex_bibfiles = ["examples/references.bib"] @@ -63,11 +72,23 @@ "secondary_sidebar_items": ["page-toc", "edit-this-page", "sourcelink", "donate"], "navbar_start": ["navbar-logo"], "icon_links": [ - { "url": "https://github.com/pymc-devs/pymc-bart", "icon": "fa-brands fa-github", "name": "GitHub" }, - { "url": "https://twitter.com/pymc_devs/", "icon": "fa-brands fa-twitter", "name": "Twitter" }, - { "url": "https://www.youtube.com/c/PyMCDevelopers", "icon": "fa-brands fa-youtube", "name": "YouTube" }, - { "url": "https://discourse.pymc.io", "icon": "fa-brands fa-discourse", "name": "Discourse" }, - ] + { + "url": "https://github.com/pymc-devs/pymc-bart", + "icon": "fa-brands fa-github", + "name": "GitHub", + }, + { + "url": "https://twitter.com/pymc_devs/", + "icon": "fa-brands fa-twitter", + "name": "Twitter", + }, + { + "url": "https://www.youtube.com/c/PyMCDevelopers", + "icon": "fa-brands fa-youtube", + "name": "YouTube", + }, + {"url": "https://discourse.pymc.io", "icon": "fa-brands fa-discourse", "name": "Discourse"}, + ], } version = os.environ.get("READTHEDOCS_VERSION", "") diff --git a/docs/index.rst b/docs/index.rst index 4ade70b..4b1dd0e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -43,7 +43,7 @@ Assuming a standard Python environment is installed on your machine, PyMC-BART i .. code-block:: bash - conda install -c conda-forge pymc-bart + conda install -c conda-forge pymc-bart **Development** diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index 6a0d7b6..2e98b7d 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -18,13 +18,24 @@ from pymc_bart.split_rules import ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule from pymc_bart.utils import ( plot_convergence, - plot_pdp, - plot_ice, plot_dependence, + plot_ice, + plot_pdp, plot_variable_importance, ) -__all__ = ["BART", "PGBART"] +__all__ = [ + "BART", + "PGBART", + "ContinuousSplitRule", + "OneHotSplitRule", + "SubsetSplitRule", + "plot_convergence", + "plot_dependence", + "plot_ice", + "plot_pdp", + "plot_variable_importance", +] __version__ = "0.5.7" diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index 83896dd..988485a 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -43,9 +43,7 @@ class BARTRV(RandomVariable): _print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}") all_trees = List[List[List[Tree]]] - def _supp_shape_from_params( - self, dist_params, rep_param_idx=1, param_shapes=None - ): # pylint: disable=arguments-renamed + def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): # pylint: disable=arguments-renamed return dist_params[0].shape[:1] @classmethod diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 32d545d..b4cf23b 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -13,8 +13,9 @@ # limitations under the License. from typing import List, Optional, Tuple, Union -import numpy.typing as npt + import numpy as np +import numpy.typing as npt from numba import njit from pymc.model import Model, modelcontext from pymc.pytensorf import inputvars, join_nonshared_inputs, make_shared_replacements @@ -25,8 +26,14 @@ from pytensor.tensor.var import Variable from pymc_bart.bart import BARTRV -from pymc_bart.tree import Node, Tree, get_idx_left_child, get_idx_right_child, get_depth from pymc_bart.split_rules import ContinuousSplitRule +from pymc_bart.tree import ( + Node, + Tree, + get_depth, + get_idx_left_child, + get_idx_right_child, +) class ParticleTree: @@ -109,7 +116,7 @@ class PGBART(ArrayStepShared): generates_stats = True stats_dtypes = [{"variable_inclusion": object, "tune": bool}] - def __init__( + def __init__( # noqa: PLR0915 self, vars=None, # pylint: disable=redefined-builtin num_particles: int = 10, @@ -543,11 +550,10 @@ def draw_leaf_value( if y_mu_pred.size == 1: mu_mean = np.full(shape, y_mu_pred.item() / m) + norm + elif y_mu_pred.size < 3 or response == "constant": + mu_mean = fast_mean(y_mu_pred) / m + norm else: - if y_mu_pred.size < 3 or response == "constant": - mu_mean = fast_mean(y_mu_pred) / m + norm - else: - mu_mean, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred, m=m, norm=norm) + mu_mean, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred, m=m, norm=norm) return mu_mean, linear_params diff --git a/pymc_bart/split_rules.py b/pymc_bart/split_rules.py index 976c08a..5a3f6cb 100644 --- a/pymc_bart/split_rules.py +++ b/pymc_bart/split_rules.py @@ -13,8 +13,9 @@ # limitations under the License. from abc import abstractmethod -from numba import njit + import numpy as np +from numba import njit class SplitRule: diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 6c6c297..c9bac2d 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -302,9 +302,10 @@ def _traverse_tree( ) else: idx_split_variable = node.idx_split_variable - left_node_index, right_node_index = get_idx_left_child( - node_index - ), get_idx_right_child(node_index) + left_node_index, right_node_index = ( + get_idx_left_child(node_index), + get_idx_right_child(node_index), + ) if excluded is not None and idx_split_variable in excluded: prop_nvalue_left = self.get_node(left_node_index).nvalue / node.nvalue stack.append((left_node_index, weights * prop_nvalue_left, idx_split_variable)) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index 47acedc..39e0620 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -279,15 +279,14 @@ def identity(x): if var in var_discrete: axes[count].plot(new_x, p_di.mean(0), "o", color=color_mean) axes[count].plot(new_x, p_di.T, ".", color=color, alpha=alpha) + elif smooth: + x_data, y_data = _smooth_mean(new_x, p_di, "ice", smooth_kwargs) + axes[count].plot(x_data, y_data.mean(1), color=color_mean) + axes[count].plot(x_data, y_data, color=color, alpha=alpha) else: - if smooth: - x_data, y_data = _smooth_mean(new_x, p_di, "ice", smooth_kwargs) - axes[count].plot(x_data, y_data.mean(1), color=color_mean) - axes[count].plot(x_data, y_data, color=color, alpha=alpha) - else: - idx = np.argsort(new_x) - axes[count].plot(new_x[idx], p_di.mean(0)[idx], color=color_mean) - axes[count].plot(new_x[idx], p_di.T[idx], color=color, alpha=alpha) + idx = np.argsort(new_x) + axes[count].plot(new_x[idx], p_di.mean(0)[idx], color=color_mean) + axes[count].plot(new_x[idx], p_di.T[idx], color=color, alpha=alpha) axes[count].set_xlabel(x_labels[i_var]) count += 1 @@ -515,13 +514,12 @@ def _get_axes( for i in range(n_plots, len(axes)): fig.delaxes(axes[i]) axes = axes[:n_plots] + elif isinstance(ax, np.ndarray): + axes = ax + fig = ax[0].get_figure() else: - if isinstance(ax, np.ndarray): - axes = ax - fig = ax[0].get_figure() - else: - axes = [ax] - fig = ax.get_figure() # type: ignore + axes = [ax] + fig = ax.get_figure() # type: ignore return fig, axes, shape @@ -694,7 +692,7 @@ def _smooth_mean( return x_data, y_data -def plot_variable_importance( +def plot_variable_importance( # noqa: PLR0915 idata: az.InferenceData, bartrv: Variable, X: npt.NDArray[np.float_], @@ -836,8 +834,9 @@ def plot_variable_importance( r_2 = np.zeros(samples) for j in range(samples): r_2[j] = ( - pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0] - ) ** 2 + (pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]) + ** 2 + ) mean_r_2 = np.mean(r_2, dtype=float) # Identify the least important combination of variables # based on the maximum mean squared Pearson correlation diff --git a/pyproject.toml b/pyproject.toml index 8f1bbef..165ed67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,14 +1,26 @@ [tool.pytest.ini_options] minversion = "6.0" -xfail_strict=true -addopts = [ - "-vv", - "--color=yes", -] +xfail_strict = true +addopts = ["-vv", "--color=yes"] -[tool.black] +[tool.ruff] line-length = 100 +[tool.ruff.lint] +select = ["E", "F", "I", "PL", "UP", "W"] +ignore-init-module-imports = true +ignore = [ + "PLR2004", # Checks for the use of unnamed numerical constants ("magic") values in comparisons. +] + +[tool.ruff.lint.pylint] +max-args = 19 +max-branches = 15 + +[tool.ruff.extend-per-file-ignores] +"docs/conf.py" = ["E501", "F541"] +"tests/test_*.py" = ["F841"] + [tool.coverage.report] exclude_lines = [ "pragma: nocover", diff --git a/requirements-dev.txt b/requirements-dev.txt index 8913d99..26118bf 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,7 +1,4 @@ -black==22.3.0 click==8.0.4 -mypy==1.3.0 -pandas-stubs==1.5.3.230304 pre-commit pylint==2.17.4 pytest-cov>=2.6.1 diff --git a/setupegg.py b/setupegg.py index 14aae29..6c6a13d 100755 --- a/setupegg.py +++ b/setupegg.py @@ -17,7 +17,5 @@ A setup.py script to use setuptools, which gives egg goodness, etc. """ -from setuptools import setup - with open("setup.py") as s: exec(s.read()) diff --git a/tests/test_bart.py b/tests/test_bart.py index 0c229f6..dfbd86f 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -1,7 +1,6 @@ import numpy as np import pymc as pm import pytest -from numpy.random import RandomState from numpy.testing import assert_almost_equal, assert_array_equal from pymc.initial_point import make_initial_point_fn from pymc.logprob.basic import joint_logp @@ -230,7 +229,6 @@ def test_bart_moment(size, expected): ids=["continuous", "one-hot", "subset", "separate-trees"], ) def test_categorical_model(separate_trees, split_rule): - Y = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]) X = np.concatenate([Y[:, None], np.random.randint(0, 6, size=(9, 4))], axis=1) diff --git a/tests/test_pgbart.py b/tests/test_pgbart.py index 419844d..4cf4188 100644 --- a/tests/test_pgbart.py +++ b/tests/test_pgbart.py @@ -1,15 +1,16 @@ from unittest import TestCase -import pytest + import numpy as np import pymc as pm +import pytest import pymc_bart as pmb from pymc_bart.pgbart import ( NormalSampler, UniformSampler, discrete_uniform_sampler, - fast_mean, fast_linear_fit, + fast_mean, ) diff --git a/tests/test_split_rules.py b/tests/test_split_rules.py index e84810d..8fd1ebc 100644 --- a/tests/test_split_rules.py +++ b/tests/test_split_rules.py @@ -1,7 +1,7 @@ import numpy as np +import pytest from pymc_bart.split_rules import ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule -import pytest @pytest.mark.parametrize( @@ -10,7 +10,6 @@ ids=["continuous", "one_hot", "subset"], ) def test_split_rule(Rule): - # Should return None if only one available value to pick from assert Rule.get_split_value(np.zeros(1)) is None diff --git a/tests/test_tree.py b/tests/test_tree.py index 453a75e..ca2de3b 100644 --- a/tests/test_tree.py +++ b/tests/test_tree.py @@ -1,6 +1,6 @@ import numpy as np -from pymc_bart.tree import Node, get_idx_left_child, get_idx_right_child, get_depth +from pymc_bart.tree import Node, get_depth, get_idx_left_child, get_idx_right_child def test_split_node():