Closed
Description
PyMC tests are failing: pymc-devs/pymc#7740
Reproducible example:
import pytensor.tensor as pt
p = pt.eye(3)
rv = pt.random.multinomial(n=5, p=p)
rv.eval(mode="JAX")
# Array([[ 5., nan, nan],
#        [ 0.,  5.,  0.],
#        [ 0.,  0.,  5.]], dtype=float64)
Only the row with p = [1, 0, 0] contains NaNs; the other two rows are correct. Could this be a problem with the binomial draws when p=0?
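One way to probe this guess without JAX is a pure-Python sketch. It assumes the multinomial is drawn as a chain of conditional binomials, with the last category taken as the remainder. That decomposition is an assumption here, not something confirmed about the PyTensor JAX dispatch, and `multinomial_sketch`, `binomial`, and `ratio` are hypothetical helper names. Under that assumption, once an earlier category has p=1 the next conditional probability becomes 0/0, which would give the same NaN pattern as above:

```python
import math
import random

EYE = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

def ratio(num, den):
    # Plain Python raises ZeroDivisionError on a zero denominator, whereas
    # floating-point array libraries return NaN (0/0) or inf (x/0).
    # Emulate the array behaviour here.
    if den == 0:
        return math.nan if num == 0 else math.copysign(math.inf, num)
    return num / den

def binomial(rng, n, p):
    # Hypothetical stand-in for a backend binomial sampler; NaN inputs
    # propagate to the output.
    if math.isnan(n) or math.isnan(p):
        return math.nan
    return float(sum(rng.random() < p for _ in range(int(n))))

def multinomial_sketch(rng, n, p, guard=False):
    """Chain of conditional binomials; the last category is the remainder."""
    remaining_n, remaining_p, draws = float(n), 1.0, []
    for p_i in p[:-1]:
        if guard and remaining_p <= 0:
            # No probability mass left, so no draws can land in this category.
            p_cond = 0.0
        else:
            p_cond = ratio(p_i, remaining_p)
        x = binomial(rng, remaining_n, p_cond)
        draws.append(x)
        remaining_n -= x
        remaining_p -= p_i
    return draws + [remaining_n]

rng = random.Random(0)
naive = [multinomial_sketch(rng, 5, row) for row in EYE]
guarded = [multinomial_sketch(rng, 5, row, guard=True) for row in EYE]
print(naive)
print(guarded)
```

If the real implementation follows the same scheme, guarding the division (as in `guard=True`) would be one possible fix. If it does not, the sketch at least shows that p=0 alone is not enough to produce NaNs in this scheme: the rows starting with p=0 come out fine, and the problem only appears after a category with p=1.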
CC @educhesne