Skip to content

Commit 18554ab

Browse files
committed
add pylint fixes
1 parent 4a2fe06 commit 18554ab

File tree

3 files changed

+33
-26
lines changed

3 files changed

+33
-26
lines changed

pymc_bart/pgbart.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,13 @@
2727

2828
from pymc_bart.bart import BARTRV
2929
from pymc_bart.split_rules import ContinuousSplitRule
30-
from pymc_bart.tree import Node, Tree, get_depth, get_idx_left_child, get_idx_right_child
30+
from pymc_bart.tree import (
31+
Node,
32+
Tree,
33+
get_depth,
34+
get_idx_left_child,
35+
get_idx_right_child,
36+
)
3137

3238

3339
class ParticleTree:
@@ -110,7 +116,7 @@ class PGBART(ArrayStepShared):
110116
generates_stats = True
111117
stats_dtypes = [{"variable_inclusion": object, "tune": bool}]
112118

113-
def __init__(
119+
def __init__( # noqa: PLR0915
114120
self,
115121
vars=None, # pylint: disable=redefined-builtin
116122
num_particles: int = 10,
@@ -544,11 +550,10 @@ def draw_leaf_value(
544550

545551
if y_mu_pred.size == 1:
546552
mu_mean = np.full(shape, y_mu_pred.item() / m) + norm
553+
elif y_mu_pred.size < 3 or response == "constant":
554+
mu_mean = fast_mean(y_mu_pred) / m + norm
547555
else:
548-
if y_mu_pred.size < 3 or response == "constant":
549-
mu_mean = fast_mean(y_mu_pred) / m + norm
550-
else:
551-
mu_mean, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred, m=m, norm=norm)
556+
mu_mean, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred, m=m, norm=norm)
552557

553558
return mu_mean, linear_params
554559

pymc_bart/utils.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -279,15 +279,14 @@ def identity(x):
279279
if var in var_discrete:
280280
axes[count].plot(new_x, p_di.mean(0), "o", color=color_mean)
281281
axes[count].plot(new_x, p_di.T, ".", color=color, alpha=alpha)
282+
elif smooth:
283+
x_data, y_data = _smooth_mean(new_x, p_di, "ice", smooth_kwargs)
284+
axes[count].plot(x_data, y_data.mean(1), color=color_mean)
285+
axes[count].plot(x_data, y_data, color=color, alpha=alpha)
282286
else:
283-
if smooth:
284-
x_data, y_data = _smooth_mean(new_x, p_di, "ice", smooth_kwargs)
285-
axes[count].plot(x_data, y_data.mean(1), color=color_mean)
286-
axes[count].plot(x_data, y_data, color=color, alpha=alpha)
287-
else:
288-
idx = np.argsort(new_x)
289-
axes[count].plot(new_x[idx], p_di.mean(0)[idx], color=color_mean)
290-
axes[count].plot(new_x[idx], p_di.T[idx], color=color, alpha=alpha)
287+
idx = np.argsort(new_x)
288+
axes[count].plot(new_x[idx], p_di.mean(0)[idx], color=color_mean)
289+
axes[count].plot(new_x[idx], p_di.T[idx], color=color, alpha=alpha)
291290
axes[count].set_xlabel(x_labels[i_var])
292291

293292
count += 1
@@ -515,13 +514,12 @@ def _get_axes(
515514
for i in range(n_plots, len(axes)):
516515
fig.delaxes(axes[i])
517516
axes = axes[:n_plots]
517+
elif isinstance(ax, np.ndarray):
518+
axes = ax
519+
fig = ax[0].get_figure()
518520
else:
519-
if isinstance(ax, np.ndarray):
520-
axes = ax
521-
fig = ax[0].get_figure()
522-
else:
523-
axes = [ax]
524-
fig = ax.get_figure() # type: ignore
521+
axes = [ax]
522+
fig = ax.get_figure() # type: ignore
525523

526524
return fig, axes, shape
527525

@@ -694,7 +692,7 @@ def _smooth_mean(
694692
return x_data, y_data
695693

696694

697-
def plot_variable_importance(
695+
def plot_variable_importance( # noqa: PLR0915
698696
idata: az.InferenceData,
699697
bartrv: Variable,
700698
X: npt.NDArray[np.float_],

pyproject.toml

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
[tool.pytest.ini_options]
22
minversion = "6.0"
3-
xfail_strict=true
4-
addopts = [
5-
"-vv",
6-
"--color=yes",
7-
]
3+
xfail_strict = true
4+
addopts = ["-vv", "--color=yes"]
85

96
[tool.ruff]
107
line-length = 100
118

129
[tool.ruff.lint]
1310
select = ["E", "F", "I", "PL", "UP", "W"]
1411
ignore-init-module-imports = true
12+
ignore = [
13+
"PLR2004", # Checks for the use of unnamed numerical constants ("magic") values in comparisons.
14+
]
15+
16+
[tool.ruff.lint.pylint]
17+
max-args = 19
18+
max-branches = 15
1519

1620
[tool.ruff.extend-per-file-ignores]
1721
"docs/conf.py" = ["E501", "F541"]

0 commit comments

Comments
 (0)