We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0824dba commit 58840baCopy full SHA for 58840ba
tests/link/jax/test_basic.py
@@ -76,7 +76,7 @@ def compare_jax_and_py(
76
if isinstance(jax_res, list):
77
assert all(isinstance(res, jax.Array) for res in jax_res)
78
else:
79
- assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
+ assert isinstance(jax_res, jax.Array)
80
81
pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
82
py_res = pytensor_py_fn(*test_inputs)
0 commit comments