diff --git a/tests/link/jax/test_basic.py b/tests/link/jax/test_basic.py index 84a8fb307f..e20bd255fb 100644 --- a/tests/link/jax/test_basic.py +++ b/tests/link/jax/test_basic.py @@ -71,9 +71,7 @@ def compare_jax_and_py( if must_be_device_array: if isinstance(jax_res, list): - assert all( - isinstance(res, jax.interpreters.xla.DeviceArray) for res in jax_res - ) + assert all(isinstance(res, jax.Array) for res in jax_res) else: assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) @@ -146,13 +144,13 @@ def test_shared(): pytensor_jax_fn = function([], a, mode="JAX") jax_res = pytensor_jax_fn() - assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) + assert isinstance(jax_res, jax.Array) np.testing.assert_allclose(jax_res, a.get_value()) pytensor_jax_fn = function([], a * 2, mode="JAX") jax_res = pytensor_jax_fn() - assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) + assert isinstance(jax_res, jax.Array) np.testing.assert_allclose(jax_res, a.get_value() * 2) # Changed the shared value and make sure that the JAX-compiled @@ -161,7 +159,7 @@ def test_shared(): a.set_value(new_a_value) jax_res = pytensor_jax_fn() - assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) + assert isinstance(jax_res, jax.Array) np.testing.assert_allclose(jax_res, new_a_value * 2)