diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 37f9362ed1..b298492915 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -412,7 +412,9 @@ def _binomial_sample_fn(carry, p_rng): remaining_n, remaining_p = carry p, rng = p_rng samples = jnp.where( - p == 0, 0, jax.random.binomial(rng, remaining_n, p / remaining_p) + remaining_n == 0, + 0, + jax.random.binomial(rng, remaining_n, p / remaining_p), ) remaining_n -= samples remaining_p -= p