From 3debfdc3989a45fcc4277ab1661874b3499b720a Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Tue, 1 Apr 2025 15:15:33 +0200 Subject: [PATCH] More stable fix for JAX Multinomial --- pytensor/link/jax/dispatch/random.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 37f9362ed1..b298492915 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -412,7 +412,9 @@ def _binomial_sample_fn(carry, p_rng): remaining_n, remaining_p = carry p, rng = p_rng samples = jnp.where( - p == 0, 0, jax.random.binomial(rng, remaining_n, p / remaining_p) + remaining_n == 0, + 0, + jax.random.binomial(rng, remaining_n, p / remaining_p), ) remaining_n -= samples remaining_p -= p