Skip to content

Commit 58840ba

Browse files
committed
Fix JAX test check
1 parent 0824dba commit 58840ba

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/link/jax/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def compare_jax_and_py(
7676
if isinstance(jax_res, list):
7777
assert all(isinstance(res, jax.Array) for res in jax_res)
7878
else:
79-
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
79+
assert isinstance(jax_res, jax.Array)
8080

8181
pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
8282
py_res = pytensor_py_fn(*test_inputs)

0 commit comments

Comments
 (0)