NaNs in JAX multinomial dispatch #1327

Closed
@ricardoV94

Description

PyMC tests are failing: pymc-devs/pymc#7740

Reproducible example:

```python
import pytensor.tensor as pt

p = pt.eye(3)
rv = pt.random.multinomial(n=5, p=p)
rv.eval(mode="JAX")
# Array([[ 5., nan, nan],
#        [ 0.,  5.,  0.],
#        [ 0.,  0.,  5.]], dtype=float64)
```

I guess it could be a problem with the binomial sampling when p=0?
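One way to read the output: if the multinomial is sampled as a chain of conditional binomials (an assumption here, not confirmed against the PyTensor JAX dispatch), category `i` uses probability `p[i] / (1 - sum(p[:i]))`. For the first row `p = [1, 0, 0]`, the first category absorbs all of `n`, so the next conditional probability is `0 / 0`, which is NaN in array arithmetic, and the NaN then propagates into the remaining counts. The other two rows never hit a zero denominator, which matches them coming out clean. The plain-Python sketch below shows where the `0 / 0` appears and one possible guard (setting the conditional probability to 0 once no probability mass remains). The names `conditional_probs`, `safe_conditional_probs`, and `multinomial` are hypothetical.

```python
import math
import random


def conditional_probs(p):
    """Conditional binomial probabilities q_i = p_i / (1 - sum(p[:i])).

    Hypothetical helper. Mimics array-library semantics where 0/0 gives NaN
    (plain Python floats would raise ZeroDivisionError instead). Assumes p is
    a valid probability vector, so p_i == 0 whenever the remaining mass is 0.
    The last category is not sampled; it receives whatever count is left.
    """
    qs = []
    for i, p_i in enumerate(p[:-1]):
        remaining = 1.0 - sum(p[:i])
        qs.append(p_i / remaining if remaining != 0.0 else math.nan)
    return qs


def safe_conditional_probs(p):
    """Same as conditional_probs, but uses 0 when no mass is left (assumed fix)."""
    qs = []
    for i, p_i in enumerate(p[:-1]):
        remaining = 1.0 - sum(p[:i])
        qs.append(0.0 if remaining <= 0.0 else min(1.0, p_i / remaining))
    return qs


def binomial(rng, n, q):
    # Sum of n Bernoulli(q) trials; adequate for small n in a sketch.
    return sum(rng.random() < q for _ in range(n))


def multinomial(rng, n, p):
    """Hypothetical multinomial sampler built from sequential binomials."""
    counts = []
    left = n
    for q in safe_conditional_probs(p):
        k = binomial(rng, left, q)
        counts.append(k)
        left -= k
    counts.append(left)  # the last category takes the remainder
    return counts
```

With `p = [1.0, 0.0, 0.0]`, `conditional_probs` returns `[1.0, nan]`, while `safe_conditional_probs` returns `[1.0, 0.0]`, and sampling each row of `eye(3)` with `n=5` gives the expected one-hot counts.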

CC @educhesne
