Skip to content

Commit da5f7f0

Browse files
committed
Work-around for numpy bug in choice with size=()
1 parent c6909a2 commit da5f7f0

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

pytensor/tensor/random/basic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2084,6 +2084,10 @@ def rng_fn(self, *params):
20842084
batch_ndim = max(batch_ndim, size_ndim)
20852085

20862086
if batch_ndim == 0:
2087+
# Numpy choice fails with size=() if a.ndim > 1 is batched
2088+
# https://github.com/numpy/numpy/issues/26518
2089+
if core_shape == ():
2090+
core_shape = None
20872091
return rng.choice(a, p=p, size=core_shape, replace=False)
20882092

20892093
# Numpy choice doesn't have a concept of batch dims

tests/tensor/random/test_basic.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1422,6 +1422,15 @@ def test_choice_samples():
14221422
compare_sample_values(choice, pt.as_tensor_variable([1, 2, 3]), 2, replace=True)
14231423

14241424

1425+
def test_choice_scalar_size():
1426+
np.testing.assert_array_equal(
1427+
choice([[1, 2, 3]], size=(), replace=True).eval(), [1, 2, 3]
1428+
)
1429+
np.testing.assert_array_equal(
1430+
choice([[1, 2, 3]], size=(), replace=False).eval(), [1, 2, 3]
1431+
)
1432+
1433+
14251434
def test_permutation_samples():
14261435
compare_sample_values(
14271436
permutation,

0 commit comments

Comments
 (0)