Skip to content

Commit 34cf7c1

Browse files
committed
Fix nan in jax implementation of Multinomial
1 parent 0b56ed9 commit 34cf7c1

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

pytensor/link/jax/dispatch/random.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ def sample_fn(rng_key, size, dtype, n, p):
411411
def _binomial_sample_fn(carry, p_rng):
412412
s, rho = carry
413413
p, rng = p_rng
414-
samples = jax.random.binomial(rng, s, p / rho)
414+
samples = jnp.where(p == 0, 0, jax.random.binomial(rng, s, p / rho))
415415
s = s - samples
416416
rho = rho - p
417417
return ((s, rho), samples)

tests/link/jax/test_random.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,12 @@ def test_multinomial():
733733
samples.std(axis=0), np.sqrt(n[0, :, None] * p * (1 - p)), rtol=0.1
734734
)
735735

736+
# Test with p=0
737+
g = pt.random.multinomial(n=5, p=pt.eye(4))
738+
g_fn = compile_random_function([], g, mode="JAX")
739+
samples = g_fn()
740+
np.testing.assert_array_equal(samples, np.eye(4) * 5)
741+
736742

737743
@pytest.mark.skipif(not numpyro_available, reason="VonMises dispatch requires numpyro")
738744
def test_vonmises_mu_outside_circle():

0 commit comments

Comments
 (0)