Skip to content

Commit e757477

Browse files
author
juanitorduz
committed
fix some hints
1 parent 372b581 commit e757477

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

pymc_bart/utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
import warnings
44

5-
import pytensor.tensor as pt
65
import arviz as az
76
import matplotlib.pyplot as plt
87
import numpy as np
98
import numpy.typing as npt
9+
import pytensor.tensor as pt
1010
from pytensor.tensor.var import Variable
1111
from scipy.interpolate import griddata
1212
from scipy.signal import savgol_filter
@@ -22,7 +22,7 @@ def _sample_posterior(
2222
all_trees: List[List[Tree]],
2323
X: TensorLike,
2424
rng: np.random.Generator,
25-
size=Optional[Union[int, Tuple[int, ...]]],
25+
size: Optional[Union[int, Tuple[int, ...]]] = None,
2626
excluded: Optional[List[int]] = None,
2727
) -> npt.NDArray[np.float_]:
2828
"""
@@ -46,12 +46,14 @@ def _sample_posterior(
4646
X = X.eval()
4747

4848
if size is None:
49-
size = ()
49+
size_iter: Union[List, Tuple] = ()
5050
elif isinstance(size, int):
51-
size = [size]
51+
size_iter = [size]
52+
else:
53+
size_iter = size
5254

5355
flatten_size = 1
54-
for s in size:
56+
for s in size_iter:
5557
flatten_size *= s
5658

5759
idx = rng.integers(0, len(stacked_trees), size=flatten_size)
@@ -62,7 +64,7 @@ def _sample_posterior(
6264
for ind, p in enumerate(pred):
6365
for tree in stacked_trees[idx[ind]]:
6466
p += np.vstack([tree.predict(x, excluded) for x in X])
65-
pred.reshape((*size, shape, -1))
67+
pred.reshape((*size_iter, shape, -1))
6668
return pred
6769

6870

0 commit comments

Comments
 (0)