Closed
Issue with current documentation:
While reproducing this example, I encountered the following error:
```
AssertionError: (Eye{dtype='float64'}(<Scalar(float64, shape=())>, <Scalar(float64, shape=())>, 0), 'float64')
```

The error is raised from `test_jax_Eye()`.
Complete code to reproduce the error:
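```python
import jax.numpy as jnp
from pytensor.tensor.basic import Eye
from pytensor.link.jax.dispatch import jax_funcify
# Test helper from the `tests/` directory of the PyTensor repository
from tests.link.jax.test_basic import compare_jax_and_py
from pytensor.graph import FunctionGraph
import pytensor.tensor as pt


# Register a JAX implementation for the `Eye` `Op`
@jax_funcify.register(Eye)
def jax_funcify_Eye(op):
    # Read the static output dtype from the `Op` instance
    dtype = op.dtype

    def eye(N, M, k):
        return jnp.eye(N, M, k, dtype=dtype)

    return eye


def test_jax_Eye():
    """Test JAX conversion of the `Eye` `Op`."""
    # Symbolic scalar (float64 here, per the error above) used as the size
    x_at = pt.scalar()
    eye_var = pt.eye(x_at)
    out_fg = FunctionGraph(outputs=[eye_var])
    # Compare the JAX backend against the Python backend with x_at = 3
    compare_jax_and_py(out_fg, [3])


test_jax_Eye()
```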
Idea or request for content:
Instead of passing `x_at` to `pt.eye`, an integer can be used, like so:
`pytensor/tests/link/jax/test_tensor_basic.py`, lines 207 to 212 (commit e8693bd)
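As a hedged illustration of why an integer constant may behave differently from a symbolic `pt.scalar()`, here is a pure-JAX sketch. It assumes, without this being confirmed in the issue, that the relevant constraint is that `jnp.eye` needs its `N` and `M` arguments to be concrete integers when the function is traced by `jax.jit`. The names `eye`, `eye_static`, `eye_traced`, `result`, and `traced_failed` are hypothetical; `eye` mirrors the function returned by `jax_funcify_Eye` above, with `float32` substituted for the `Op`'s dtype.

```python
import jax
import jax.numpy as jnp
import numpy as np


def eye(N, M, k):
    # Same body as the `eye` closure in `jax_funcify_Eye`; float32 is used
    # so the sketch runs without enabling JAX's 64-bit mode.
    return jnp.eye(N, M, k, dtype=jnp.float32)


# Sizes marked static: JAX passes them as plain Python ints, so this works.
eye_static = jax.jit(eye, static_argnums=(0, 1, 2))
result = eye_static(3, 3, 0)

# Sizes left as traced arguments: the output shape would depend on runtime
# values, which `jnp.eye` cannot accept under `jax.jit`.
eye_traced = jax.jit(eye)
try:
    eye_traced(3, 3, 0)
    traced_failed = False
except TypeError:
    # JAX reports this as a TypeError subclass
    traced_failed = True
```

Marking the sizes static means JAX recompiles for each distinct `(N, M, k)`, but `jnp.eye` then sees ordinary integers, which is roughly analogous to passing an integer constant to `pt.eye` instead of a symbolic scalar.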