Skip to content

Commit 66d85f4

Browse files
committed
Fix bug in supp_shape_from_ref_param_shape
1 parent 9ad29b1 commit 66d85f4

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

pytensor/tensor/random/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def supp_shape_from_ref_param_shape(
323323
raise ValueError("ndim_supp must be greater than 0")
324324
if param_shapes is not None:
325325
ref_param = param_shapes[ref_param_idx]
326-
return (ref_param[-ndim_supp],)
326+
return tuple(ref_param[i] for i in range(-ndim_supp, 0))
327327
else:
328328
ref_param = dist_params[ref_param_idx]
329329
if ref_param.ndim < ndim_supp:

tests/tensor/random/test_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,11 @@ def test_supp_shape_from_ref_param_shape():
313313
ref_param_idx=1,
314314
)
315315
assert res == (3, 4)
316+
317+
res = supp_shape_from_ref_param_shape(
318+
ndim_supp=2,
319+
dist_params=(np.array([1, 2]), np.ones((2, 3, 4))),
320+
param_shapes=((2,), (2, 3, 4)),
321+
ref_param_idx=1,
322+
)
323+
assert res == (3, 4)

0 commit comments

Comments
 (0)