2
2
3
3
import warnings
4
4
5
- import pytensor .tensor as pt
6
5
import arviz as az
7
6
import matplotlib .pyplot as plt
8
7
import numpy as np
9
8
import numpy .typing as npt
9
+ import pytensor .tensor as pt
10
10
from pytensor .tensor .var import Variable
11
11
from scipy .interpolate import griddata
12
12
from scipy .signal import savgol_filter
@@ -22,7 +22,7 @@ def _sample_posterior(
22
22
all_trees : List [List [Tree ]],
23
23
X : TensorLike ,
24
24
rng : np .random .Generator ,
25
- size = Optional [Union [int , Tuple [int , ...]]],
25
+ size : Optional [Union [int , Tuple [int , ...]]] = None ,
26
26
excluded : Optional [List [int ]] = None ,
27
27
) -> npt .NDArray [np .float_ ]:
28
28
"""
@@ -46,12 +46,14 @@ def _sample_posterior(
46
46
X = X .eval ()
47
47
48
48
if size is None :
49
- size = ()
49
+ size_iter : Union [ List , Tuple ] = ()
50
50
elif isinstance (size , int ):
51
- size = [size ]
51
+ size_iter = [size ]
52
+ else :
53
+ size_iter = size
52
54
53
55
flatten_size = 1
54
- for s in size :
56
+ for s in size_iter :
55
57
flatten_size *= s
56
58
57
59
idx = rng .integers (0 , len (stacked_trees ), size = flatten_size )
@@ -62,7 +64,7 @@ def _sample_posterior(
62
64
for ind , p in enumerate (pred ):
63
65
for tree in stacked_trees [idx [ind ]]:
64
66
p += np .vstack ([tree .predict (x , excluded ) for x in X ])
65
- pred .reshape ((* size , shape , - 1 ))
67
+ pred .reshape ((* size_iter , shape , - 1 ))
66
68
return pred
67
69
68
70
0 commit comments